
[debug] handle cache with prompt

boji123 1 year ago
parent commit
8130abb5ea
2 changed files with 4 additions and 3 deletions
  1. cosyvoice/flow/flow.py  +1 -0
  2. cosyvoice/flow/flow_matching.py  +3 -3

+ 1 - 0
cosyvoice/flow/flow.py

@@ -141,6 +141,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
             spks=embedding,
             cond=conds,
             n_timesteps=10,
+            prompt_len=mel_len1,
             required_cache_size=required_cache_size,
             flow_cache=flow_cache
         )
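
The call above now forwards the prompt length into the flow-matching sampler. From this diff alone, mel_len1 is presumably the number of mel frames belonging to the prompt audio; the variable is defined outside the hunk, so that reading is an assumption. A minimal standalone sketch of what the argument means, with hypothetical names (not CosyVoice code):

    import torch

    # Assumed layout [B, C, T]: prompt frames first, then frames to generate.
    # prompt_len (mel_len1 at the call site) marks how many leading frames of
    # the conditioning/noise tensors are prompt, so the sampler can pin that
    # region in its streaming cache.
    mu = torch.randn(1, 80, 120)
    prompt_len = 40
    prompt_region = mu[..., :prompt_len]
    print(prompt_region.shape)  # torch.Size([1, 80, 40])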

+ 3 - 3
cosyvoice/flow/flow_matching.py

@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, required_cache_size=0, flow_cache=None):
         """Forward diffusion
 
         Args:
@@ -62,8 +62,8 @@ class ConditionalCFM(BASECFM):
 
         next_cache_start = max(z.size(2) - required_cache_size, 0)
         flow_cache = [
-            z[..., next_cache_start:],
-            mu[..., next_cache_start:]
+            torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2),
+            torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2)
         ]
 
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
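
A minimal sketch of the resulting cache layout, assuming [B, C, T] tensors; build_flow_cache is a hypothetical helper that mirrors the two edited lines. Instead of caching only the last required_cache_size frames, the cache now keeps the prompt frames in front of that tail window, so later streaming chunks still see the prompt:

    import torch

    def build_flow_cache(z: torch.Tensor, mu: torch.Tensor,
                         prompt_len: int, required_cache_size: int):
        # Keep the prompt region plus the most recent tail window of both the
        # noise z and the conditioning mu, concatenated along the time axis.
        next_cache_start = max(z.size(2) - required_cache_size, 0)
        return [
            torch.cat((z[..., :prompt_len], z[..., next_cache_start:]), dim=2),
            torch.cat((mu[..., :prompt_len], mu[..., next_cache_start:]), dim=2),
        ]

    z = torch.randn(1, 80, 50)
    mu = torch.randn(1, 80, 50)
    cache = build_flow_cache(z, mu, prompt_len=20, required_cache_size=8)
    print([c.shape for c in cache])  # [torch.Size([1, 80, 28]), torch.Size([1, 80, 28])]

With the default prompt_len=0 the prompt slice is empty and torch.cat reduces to the old tail-only cache, so call sites that do not pass prompt_len keep the previous behaviour.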