Merge pull request #497 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
Xiang Lyu 1 year ago
parent
commit
0f19b97c5a
3 changed files with 37 additions and 16 deletions
  1. cosyvoice/cli/model.py  +18 -10
  2. cosyvoice/flow/flow.py  +7 -4
  3. cosyvoice/flow/flow_matching.py  +12 -2

+ 18 - 10
cosyvoice/cli/model.py

@@ -53,6 +53,7 @@ class CosyVoiceModel:
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.mel_overlap_dict = {}
+        self.flow_cache_dict = {}
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
@@ -100,15 +101,18 @@ class CosyVoiceModel:
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        tts_mel = self.flow.inference(token=token.to(self.device),
-                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                      prompt_token=prompt_token.to(self.device),
-                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                      prompt_feat=prompt_feat.to(self.device),
-                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                      embedding=embedding.to(self.device))
+        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
+                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_token=prompt_token.to(self.device),
+                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_feat=prompt_feat.to(self.device),
+                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                  embedding=embedding.to(self.device),
+                                                  flow_cache=self.flow_cache_dict[uuid])
+        self.flow_cache_dict[uuid] = flow_cache
+
         # mel overlap fade in out
-        if self.mel_overlap_dict[uuid] is not None:
+        if self.mel_overlap_dict[uuid].shape[2] != 0:
             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
@@ -145,7 +149,9 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
@@ -201,7 +207,9 @@ class CosyVoiceModel:
         this_uuid = str(uuid.uuid1())
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
-            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+            self.hift_cache_dict[this_uuid] = None
+            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
+            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
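
Note on the model.py hunks: each new session now seeds empty tensors instead of None, so the streaming path can branch on tensor shape (.shape[2] != 0), and token2wav threads a per-uuid flow cache through consecutive chunks. A rough caller-side sketch, assuming a loaded CosyVoiceModel; model, chunks, and the prompt tensors are illustrative placeholders, and token2wav is assumed to return the synthesized chunk:

    import torch

    # Hypothetical streaming loop: `model` is a loaded CosyVoiceModel, `chunks` a list
    # of speech-token chunks, prompt_token/prompt_feat/embedding prepared elsewhere.
    this_uuid = "session-0"
    model.tts_speech_token_dict[this_uuid], model.llm_end_dict[this_uuid] = [], True
    model.hift_cache_dict[this_uuid] = None
    model.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)    # empty: no overlap yet
    model.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)  # empty: nothing cached yet

    for i, token_chunk in enumerate(chunks):
        # token2wav reads and rewrites flow_cache_dict[this_uuid] internally, so
        # consecutive chunks share the same noise/conditioning in the overlap region.
        speech = model.token2wav(token=token_chunk,
                                 prompt_token=prompt_token,
                                 prompt_feat=prompt_feat,
                                 embedding=embedding,
                                 uuid=this_uuid,
                                 finalize=(i == len(chunks) - 1))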

+ 7 - 4
cosyvoice/flow/flow.py

@@ -109,7 +109,8 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_token_len,
                   prompt_feat,
                   prompt_feat_len,
-                  embedding):
+                  embedding,
+                  flow_cache):
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -133,13 +134,15 @@ class MaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat = self.decoder(
+        feat, flow_cache = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
-            n_timesteps=10
+            n_timesteps=10,
+            prompt_len=mel_len1,
+            flow_cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat
+        return feat, flow_cache
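
Note on the flow.py change: inference now returns a (mel, cache) pair, so callers must unpack both; the returned mel still has the prompt frames stripped (feat[:, :, mel_len1:]), while the cache covers the prompt plus the 34-frame overlap handled in flow_matching.py below. A minimal shape sketch with illustrative lengths:

    import torch

    # Illustrative lengths only; mel_len1/mel_len2 stand in for the prompt and
    # generated mel lengths computed inside MaskedDiffWithXvec.inference.
    mel_len1, mel_len2 = 150, 200
    feat = torch.randn(1, 80, mel_len1 + mel_len2)      # decoder output: prompt + new frames
    flow_cache = torch.randn(1, 80, mel_len1 + 34, 2)   # cached (z, mu) for prompt + overlap

    feat = feat[:, :, mel_len1:]                        # prompt frames are dropped before returning
    assert feat.shape[2] == mel_len2                    # mirrors the assert in the diff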

+ 12 - 2
cosyvoice/flow/flow_matching.py

@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
         """Forward diffusion
 
         Args:
@@ -50,11 +50,21 @@ class ConditionalCFM(BASECFM):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
+
         z = torch.randn_like(mu) * temperature
+        cache_size = flow_cache.shape[2]
+        # fix prompt and overlap part mu and z
+        if cache_size != 0:
+            z[:, :, :cache_size] = flow_cache[:, :, :, 0]
+            mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
+        z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
+        mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
+        flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
+
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """