Browse Source

add speech fade in out

lyuxiang.lx 1 năm trước cách đây
mục cha
commit
f65eca6723
2 tập tin đã thay đổi với 5 bổ sung6 xóa
  1. 5 5
      cosyvoice/cli/model.py
  2. 0 1
      cosyvoice/utils/common.py

+ 5 - 5
cosyvoice/cli/model.py

@@ -40,6 +40,8 @@ class CosyVoiceModel:
         # hift cache
         self.mel_cache_len = 20
         self.source_cache_len = int(self.mel_cache_len * 256)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
@@ -50,7 +52,6 @@ class CosyVoiceModel:
         self.llm_end_dict = {}
         self.mel_overlap_dict = {}
         self.hift_cache_dict = {}
-        self.speech_window = np.hamming(2 * self.source_cache_len)
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -117,10 +118,9 @@ class CosyVoiceModel:
             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
             if self.hift_cache_dict[uuid] is not None:
                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
-            self.hift_cache_dict[uuid] = {
-                'mel': tts_mel[:, :, -self.mel_cache_len:],
-                'source': tts_source[:, :, -self.source_cache_len:],
-                'speech': tts_speech[:, -self.source_cache_len:]}
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
             tts_speech = tts_speech[:, :-self.source_cache_len]
         else:
             if speed != 1.0:

+ 0 - 1
cosyvoice/utils/common.py

@@ -139,7 +139,6 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
     device = fade_in_mel.device
     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
     mel_overlap_len = int(window.shape[0] / 2)
-
     fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)