lyuxiang.lx 1 mesiac pred
rodič
commit
e15222b17c
1 zmenil súbory, kde vykonal 19 pridanie a 25 odobranie
  1. 19 25
      cosyvoice/cli/model.py

+ 19 - 25
cosyvoice/cli/model.py

@@ -103,35 +103,29 @@ class CosyVoiceModel:
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
                 assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
-                for i in self.llm.inference_bistream(text=text,
+                token_generator = self.llm.inference_bistream(text=text,
+                                                              prompt_text=prompt_text.to(self.device),
+                                                              prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                              prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                              embedding=llm_embedding.to(self.device))
+            else:
+                token_generator = self.llm.inference(text=text.to(self.device),
+                                                     text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                                      prompt_text=prompt_text.to(self.device),
                                                      prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                      prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                      prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                     embedding=llm_embedding.to(self.device)):
-                    if i in self.silent_tokens:
-                        cur_silent_token_num += 1
-                        if cur_silent_token_num > max_silent_token_num:
-                            continue
-                    else:
-                        cur_silent_token_num = 0
-                    self.tts_speech_token_dict[uuid].append(i)
-            else:
-                for i in self.llm.inference(text=text.to(self.device),
-                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_text=prompt_text.to(self.device),
-                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device),
-                                            uuid=uuid):
-                    if i in self.silent_tokens:
-                        cur_silent_token_num += 1
-                        if cur_silent_token_num > max_silent_token_num:
-                            continue
-                    else:
-                        cur_silent_token_num = 0
-                    self.tts_speech_token_dict[uuid].append(i)
+                                                     embedding=llm_embedding.to(self.device),
+                                                     uuid=uuid)  
+            for i in token_generator:
+                if i in self.silent_tokens:
+                    cur_silent_token_num += 1
+                    if cur_silent_token_num > max_silent_token_num:
+                        continue
+                else:
+                    cur_silent_token_num = 0
+                self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def vc_job(self, source_speech_token, uuid):