
update token args

lyuxiang.lx 1 year ago
parent
commit
ffa28e3bbd
2 changed files with 4 additions and 7 deletions
  1. + 3 - 6  cosyvoice/cli/model.py
  2. + 1 - 1  cosyvoice/llm/llm.py

+ 3 - 6
cosyvoice/cli/model.py

@@ -31,8 +31,8 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.token_min_hop_len = 100
-        self.token_max_hop_len = 200
+        self.token_min_hop_len = 2 * self.flow.input_frame_rate
+        self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
         # mel fade in out
         self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
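
Note on this hunk: the streaming hop lengths were hard-coded at 100 and 200 tokens, which implicitly assumed a 50 Hz speech-token rate. Deriving them from flow.input_frame_rate keeps the hop windows at roughly 2 s and 4 s of audio for any frame rate. A minimal sketch of the relationship, assuming input_frame_rate is measured in speech tokens per second (the 50 Hz figure is illustrative, not taken from the repo):

# Minimal sketch: hop lengths derived from the flow model's token rate.
# Assumes input_frame_rate is in speech tokens per second; 50 is illustrative.
input_frame_rate = 50
token_min_hop_len = 2 * input_frame_rate   # ~2 s of tokens -> 100 at 50 Hz
token_max_hop_len = 4 * input_frame_rate   # ~4 s of tokens -> 200 at 50 Hz
# At 50 tokens/s this reproduces the old hard-coded values (100, 200);
# at any other rate the hops scale instead of silently mismatching.
assert (token_min_hop_len, token_max_hop_len) == (100, 200)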
@@ -87,10 +87,7 @@ class CosyVoiceModel:
                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device).half(),
-                                        sampling=25,
-                                        max_token_text_ratio=30,
-                                        min_token_text_ratio=3):
+                                        embedding=llm_embedding.to(self.device).half()):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
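
Note on the second hunk: the call site stops pinning sampling, max_token_text_ratio, and min_token_text_ratio, leaving those decoding knobs to the LLM side (hence the commit message "update token args"). A hedged sketch of what the call now relies on, assuming the inference method declares these as keyword defaults; the values below merely echo the arguments deleted here and are not a quote of the real signature in cosyvoice/llm/llm.py:

# Hypothetical signature sketch: the knobs removed from the call above are
# assumed to become keyword defaults on the LLM's inference method. The
# default values echo the deleted arguments; they are NOT the real code.
def inference(self, text, text_len, prompt_text, prompt_text_len,
              prompt_speech_token, prompt_speech_token_len, embedding,
              sampling=25, max_token_text_ratio=30, min_token_text_ratio=3):
    ...  # yields speech tokens, iterated by the caller above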
 

+ 1 - 1
cosyvoice/llm/llm.py

@@ -197,7 +197,7 @@ class TransformerLM(torch.nn.Module):
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1,
+            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                   att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                  device=lm_input.device)).to(torch.bool))
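
Note on this hunk: this is the functional fix. forward_chunk was called with offset=0 on every autoregressive step, so positional encodings never advanced relative to the growing attention cache; passing the running offset keeps positions aligned with the cached keys/values. A minimal sketch of the intended bookkeeping, assuming offset is advanced by the number of new frames after each step (that update sits outside the lines shown in this hunk):

# Minimal sketch of KV-cached autoregressive decoding with a running offset.
# forward_chunk follows the WeNet-style signature used above; the offset
# update after each step is an assumption, as it is not shown in the hunk.
import torch

def decode_loop(llm, lm_input, max_len):
    offset = 0
    att_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device)
    cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device)
    for _ in range(max_len):
        y_pred, att_cache, cnn_cache = llm.forward_chunk(
            lm_input, offset=offset, required_cache_size=-1,
            att_cache=att_cache, cnn_cache=cnn_cache,
            att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                           device=lm_input.device)).to(torch.bool))
        # Advance positions past the frames just consumed so the next step's
        # positional encoding lines up with the cached keys/values.
        offset += lm_input.shape[1]
        # ... sample the next token from y_pred and build the next lm_input ...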