lyuxiang.lx 5 months ago
parent commit
6b5eef62cc
1 changed file with 10 additions and 24 deletions

cosyvoice/cli/model.py (+10, -24)

@@ -403,11 +403,6 @@ class CosyVoice3Model(CosyVoice2Model):
             self.flow.half()
         # NOTE: must match training static_chunk_size
         self.token_hop_len = 25
-        # hift cache
-        self.mel_cache_len = 8
-        self.source_cache_len = int(self.mel_cache_len * 480)
-        # speech fade in out
-        self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
@@ -428,26 +423,17 @@ class CosyVoice3Model(CosyVoice2Model):
                                              streaming=stream,
                                              finalize=finalize)
         tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
-        # append hift cache
+        # append mel cache
         if self.hift_cache_dict[uuid] is not None:
-            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            hift_cache_mel = self.hift_cache_dict[uuid]['mel']
             tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+            self.hift_cache_dict[uuid]['mel'] = tts_mel
         else:
-            hift_cache_source = torch.zeros(1, 1, 0)
-        # keep overlap mel and hift cache
-        if finalize is False:
-            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
-            if self.hift_cache_dict[uuid] is not None:
-                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
-            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
-                                          'source': tts_source[:, :, -self.source_cache_len:],
-                                          'speech': tts_speech[:, -self.source_cache_len:]}
-            tts_speech = tts_speech[:, :-self.source_cache_len]
-        else:
-            if speed != 1.0:
-                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
-                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
-            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel)
-            if self.hift_cache_dict[uuid] is not None:
-                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
+        if speed != 1.0:
+            assert token_offset == 0 and finalize is True, 'speed change is only supported in non-streaming inference mode'
+            tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+        tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
+        tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
+        self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
         return tts_speech
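
The new scheme drops the fixed-length mel/source cache and the Hamming-window cross-fade: each call appends the incoming mel chunk to the full mel accumulated for the utterance, re-vocodes the whole thing, and returns only the samples past speech_offset. Below is a minimal sketch of that offset-based streaming logic; token2wav_sketch, vocode, and the stand-in shapes are hypothetical and not the HiFTGenerator API, and the sketch assumes the vocoder is deterministic and prefix-stable so earlier samples come out identical on every call.

import torch

# Per-utterance cache, keyed by uuid (mirrors self.hift_cache_dict in the diff).
hift_cache_dict = {}

def token2wav_sketch(vocode, tts_mel, uuid, finalize=False):
    # Grow the cached mel instead of keeping a fixed-length overlap window.
    if uuid in hift_cache_dict:
        tts_mel = torch.concat([hift_cache_dict[uuid]['mel'], tts_mel], dim=2)
        hift_cache_dict[uuid]['mel'] = tts_mel
    else:
        hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
    # Re-vocode everything generated so far; `finalize` only mirrors the real
    # signature here, the stub vocoder below ignores it.
    tts_speech = vocode(tts_mel)  # [1, n_mels, T] -> [1, T * hop_len]
    # Emit only the samples that previous calls have not already returned.
    tts_speech = tts_speech[:, hift_cache_dict[uuid]['speech_offset']:]
    hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
    return tts_speech

# Stand-in "vocoder": deterministic upsampling by a fake hop length of 4.
vocode = lambda mel: mel.mean(dim=1).repeat_interleave(4, dim=1)
chunk1 = token2wav_sketch(vocode, torch.randn(1, 80, 25), uuid='u1')
chunk2 = token2wav_sketch(vocode, torch.randn(1, 80, 25), uuid='u1', finalize=True)

The trade-off is simplicity over cost: chunk joins need no fade-in/out because every chunk is cut from one continuous waveform, but as written the whole accumulated mel is passed to hift.inference on each call, so per-chunk vocoding cost grows with utterance length.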