Prechádzať zdrojové kódy

Merge pull request #379 from boji123/bj_dev_stream_fix

[debug] fix badcase, add fade on speech output
Xiang Lyu 1 rok pred
rodič
commit
cd26f11859
2 zmenil súbory, kde vykonal 12 pridanie a 3 odobranie
  1. 9 1
      cosyvoice/cli/model.py
  2. 3 2
      cosyvoice/utils/common.py

+ 9 - 1
cosyvoice/cli/model.py

@@ -50,6 +50,7 @@ class CosyVoiceModel:
         self.llm_end_dict = {}
         self.mel_overlap_dict = {}
         self.hift_cache_dict = {}
+        self.speech_window = np.hamming(2 * self.source_cache_len)
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -114,13 +115,20 @@ class CosyVoiceModel:
             self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
             tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
-            self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {
+                'mel': tts_mel[:, :, -self.mel_cache_len:],
+                'source': tts_source[:, :, -self.source_cache_len:],
+                'speech': tts_speech[:, -self.source_cache_len:]}
             tts_speech = tts_speech[:, :-self.source_cache_len]
         else:
             if speed != 1.0:
                 assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                 tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech
 
     def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),

+ 3 - 2
cosyvoice/utils/common.py

@@ -139,6 +139,7 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
     device = fade_in_mel.device
     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
     mel_overlap_len = int(window.shape[0] / 2)
-    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
-        fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+
+    fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)