Răsfoiți Sursa

fix vocoder speech overlap

lyuxiang.lx 1 an în urmă
părinte
comite
1d881df8b2
3 a modificat fișierele cu 93 adăugiri și 74 ștergeri
  1. 79 66
      cosyvoice/cli/model.py
  2. 8 4
      cosyvoice/hifigan/generator.py
  3. 6 4
      cosyvoice/utils/common.py

+ 79 - 66
cosyvoice/cli/model.py

@@ -31,18 +31,25 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.token_min_hop_len = 100
-        self.token_max_hop_len = 400
+        self.token_max_hop_len = 200
         self.token_overlap_len = 20
-        self.speech_overlap_len = 34 * 256
-        self.window = np.hamming(2 * self.speech_overlap_len)
+        # mel fade in out
+        self.mel_overlap_len = 34
+        self.mel_window = np.hamming(2 * self.mel_overlap_len)
+        # hift cache
+        self.mel_cache_len = 20
+        self.source_cache_len = int(self.mel_cache_len * 256)
+        # rtf and decoding related
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
-        self.tts_speech_token = {}
-        self.llm_end = {}
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.mel_overlap_dict = {}
+        self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -64,102 +71,108 @@ class CosyVoiceModel:
         self.flow.decoder.estimator = xxx
         self.flow.decoder.session = xxx
 
-    def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
-                                                text_len=text_len.to(self.device),
+                                                text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                                 prompt_text=prompt_text.to(self.device),
-                                                prompt_text_len=prompt_text_len.to(self.device),
+                                                prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                 prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                                prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                                prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                 embedding=llm_embedding.to(self.device).half(),
                                                 sampling=25,
                                                 max_token_text_ratio=30,
                                                 min_token_text_ratio=3):
-                self.tts_speech_token[this_uuid].append(i)
-        self.llm_end[this_uuid] = True
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
 
-    def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
         with self.flow_hift_context:
             tts_mel = self.flow.inference(token=token.to(self.device),
-                                        token_len=torch.tensor([token.size(1)], dtype=torch.int32).to(self.device),
+                                        token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_token=prompt_token.to(self.device),
-                                        prompt_token_len=prompt_token_len.to(self.device),
+                                        prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_feat=prompt_feat.to(self.device),
-                                        prompt_feat_len=prompt_feat_len.to(self.device),
+                                        prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                         embedding=embedding.to(self.device))
-            tts_speech = self.hift.inference(mel=tts_mel).cpu()
+            # mel overlap fade in out
+            if self.mel_overlap_dict[uuid] is not None:
+                tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+            # append hift cache
+            if self.hift_cache_dict[uuid] is not None:
+                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+            else:
+                hift_cache_source = torch.zeros(1, 1, 0)
+            # keep overlap mel and hift cache
+            if finalize is False:
+                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+                tts_speech = tts_speech[:, :-self.source_cache_len]
+            else:
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
         return tts_speech
 
-    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
+    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
-            self.tts_speech_token[this_uuid], self.llm_end[this_uuid] = [], False
-        p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
-                                                    llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device), this_uuid))
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
+        p.join()
         if stream is True:
-            cache_speech, cache_token, token_hop_len = None, None, self.token_min_hop_len
+            token_hop_len = self.token_min_hop_len
             while True:
                 time.sleep(0.1)
-                if len(self.tts_speech_token[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                     with self.flow_hift_context:
                         this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                    prompt_token=flow_prompt_speech_token.to(self.device),
-                                                    prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                                    prompt_feat=prompt_speech_feat.to(self.device),
-                                                    prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                                    embedding=flow_embedding.to(self.device))
-                    # fade in/out if necessary
-                    if cache_speech is not None:
-                        this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
-                    yield  {'tts_speech': this_tts_speech[:, :-self.speech_overlap_len]}
-                    cache_speech = this_tts_speech[:, -self.speech_overlap_len:]
-                    cache_token = self.tts_speech_token[this_uuid][:token_hop_len]
+                                                    prompt_token=flow_prompt_speech_token,
+                                                    prompt_feat=prompt_speech_feat,
+                                                    embedding=flow_embedding,
+                                                    uuid=this_uuid,
+                                                    finalize=False)
+                    yield  {'tts_speech': this_tts_speech.cpu()}
                     with self.lock:
-                        self.tts_speech_token[this_uuid] = self.tts_speech_token[this_uuid][token_hop_len:]
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                     # increase token_hop_len for better speech quality
                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
-                if self.llm_end[this_uuid] is True and len(self.tts_speech_token[this_uuid]) < token_hop_len + self.token_overlap_len:
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            p.join()
+            # p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
-            this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
-            if this_tts_speech_token.shape[1] < self.token_min_hop_len + self.token_overlap_len and cache_token is not None:
-                cache_token_len = self.token_min_hop_len + self.token_overlap_len - this_tts_speech_token.shape[1]
-                this_tts_speech_token = torch.concat([torch.concat(cache_token[-cache_token_len:], dim=1), this_tts_speech_token], dim=1)
-            else:
-                cache_token_len = 0
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                            prompt_token=flow_prompt_speech_token.to(self.device),
-                                            prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                            prompt_feat=prompt_speech_feat.to(self.device),
-                                            prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                            embedding=flow_embedding.to(self.device))
-                this_tts_speech = this_tts_speech[:, int(cache_token_len / this_tts_speech_token.shape[1] * this_tts_speech.shape[1]):]
-            if cache_speech is not None:
-                this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
-            yield {'tts_speech': this_tts_speech}
+                                            prompt_token=flow_prompt_speech_token,
+                                            prompt_feat=prompt_speech_feat,
+                                            embedding=flow_embedding,
+                                            uuid=this_uuid,
+                                            finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            p.join()
-            this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
+            # p.join()
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                            prompt_token=flow_prompt_speech_token.to(self.device),
-                                            prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                            prompt_feat=prompt_speech_feat.to(self.device),
-                                            prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                            embedding=flow_embedding.to(self.device))
-            yield {'tts_speech': this_tts_speech}
+                                            prompt_token=flow_prompt_speech_token,
+                                            prompt_feat=prompt_speech_feat,
+                                            embedding=flow_embedding,
+                                            uuid=this_uuid,
+                                            finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
-            self.tts_speech_token.pop(this_uuid)
-            self.llm_end.pop(this_uuid)
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
         torch.cuda.synchronize()

+ 8 - 4
cosyvoice/hifigan/generator.py

@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
         inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         f0 = self.f0_predictor(x)
         s = self._f02source(f0)
 
+        # use cache_source to avoid glitch
+        if cache_source.shape[2] == 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
 
@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):
 
         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x
+        return x, s
 
     def remove_weight_norm(self):
         print('Removing weight norm...')
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
             l.remove_weight_norm()
 
     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor) -> torch.Tensor:
-        return self.forward(x=mel)
+    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        return self.forward(x=mel, cache_source=cache_source)

+ 6 - 4
cosyvoice/utils/common.py

@@ -131,7 +131,9 @@ def random_sampling(weighted_scores, decoded_tokens, sampling):
     top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
     return top_ids
 
-def fade_in_out(fade_in_speech, fade_out_speech, window):
-    speech_overlap_len = int(window.shape[0] / 2)
-    fade_in_speech[:, :speech_overlap_len] = fade_in_speech[:, :speech_overlap_len] * window[:speech_overlap_len] + fade_out_speech[:, -speech_overlap_len:] * window[speech_overlap_len:]
-    return fade_in_speech
+def fade_in_out(fade_in_mel, fade_out_mel, window):
+    device = fade_in_mel.device
+    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+    mel_overlap_len = int(window.shape[0] / 2)
+    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+    return fade_in_mel.to(device)