Prechádzať zdrojové kódy

send streaming as args

lyuxiang.lx 8 mesiacov pred
rodič
commit
cbfed4a9ee

+ 3 - 4
cosyvoice/cli/model.py

@@ -258,9 +258,6 @@ class CosyVoice2Model(CosyVoiceModel):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
-        # NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
-        self.flow.encoder.streaming = True
-        self.flow.decoder.estimator.streaming = True
         self.hift = hift
         self.fp16 = fp16
         self.trt_concurrent = trt_concurrent
@@ -290,7 +287,7 @@ class CosyVoice2Model(CosyVoiceModel):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
             tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                              token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
@@ -299,6 +296,7 @@ class CosyVoice2Model(CosyVoiceModel):
                                              prompt_feat=prompt_feat.to(self.device),
                                              prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                              embedding=embedding.to(self.device),
+                                             streaming=stream,
                                              finalize=finalize)
         tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append hift cache
@@ -356,6 +354,7 @@ class CosyVoice2Model(CosyVoiceModel):
                                                      embedding=flow_embedding,
                                                      token_offset=token_offset,
                                                      uuid=this_uuid,
+                                                     stream=stream,
                                                      finalize=False)
                     token_offset += this_token_hop_len
                     yield {'tts_speech': this_tts_speech.cpu()}

+ 0 - 4
cosyvoice/flow/decoder.py

@@ -419,10 +419,6 @@ class CausalConditionalDecoder(ConditionalDecoder):
         Returns:
             _type_: _description_
         """
-        if hasattr(self, 'streaming'):
-            assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
-            streaming = self.streaming
-
         t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)
 

+ 4 - 2
cosyvoice/flow/flow.py

@@ -241,6 +241,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
+                  streaming,
                   finalize):
         assert token.shape[0] == 1
         # xvec projection
@@ -254,10 +255,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
 
         # text encode
         if finalize is True:
-            h, h_lengths = self.encoder(token, token_len)
+            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         else:
             token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
-            h, h_lengths = self.encoder(token, token_len, context=context)
+            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)
 
@@ -273,6 +274,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
             spks=embedding,
             cond=conds,
             n_timesteps=10,
+            streaming=streaming
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2

+ 7 - 6
cosyvoice/flow/flow_matching.py

@@ -69,7 +69,7 @@ class ConditionalCFM(BASECFM):
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
 
-    def solve_euler(self, x, t_span, mu, mask, spks, cond):
+    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
         """
         Fixed euler solver for ODEs.
         Args:
@@ -110,7 +110,8 @@ class ConditionalCFM(BASECFM):
                 x_in, mask_in,
                 mu_in, t_in,
                 spks_in,
-                cond_in
+                cond_in,
+                streaming
             )
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
@@ -122,9 +123,9 @@ class ConditionalCFM(BASECFM):
 
         return sol[-1].float()
 
-    def forward_estimator(self, x, mask, mu, t, spks, cond):
+    def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
         if isinstance(self.estimator, torch.nn.Module):
-            return self.estimator(x, mask, mu, t, spks, cond)
+            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
         else:
             estimator, trt_engine = self.estimator.acquire_estimator()
             estimator.set_input_shape('x', (2, 80, x.size(2)))
@@ -196,7 +197,7 @@ class CausalConditionalCFM(ConditionalCFM):
         self.rand_noise = torch.randn([1, 80, 50 * 300])
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
         """Forward diffusion
 
         Args:
@@ -220,4 +221,4 @@ class CausalConditionalCFM(ConditionalCFM):
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None

+ 0 - 3
cosyvoice/transformer/upsample_encoder.py

@@ -272,9 +272,6 @@ class UpsampleConformerEncoder(torch.nn.Module):
             checkpointing API because `__call__` attaches all the hooks of the module.
             https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
         """
-        if hasattr(self, 'streaming'):
-            assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
-            streaming = self.streaming
         T = xs.size(1)
         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
         if self.global_cmvn is not None: