lyuxiang.lx, 1 month ago
parent commit cfa1c115b2
1 changed file, 17 insertions and 0 deletions

cosyvoice/cli/model.py  +17 -0

@@ -60,6 +60,7 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.silent_tokens = []
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -98,6 +99,7 @@ class CosyVoiceModel:
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        cur_silent_token_num, max_silent_token_num = 0, 5
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
                 assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and does not support vllm!'
@@ -107,6 +109,12 @@ class CosyVoiceModel:
                                                      prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                      prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                      embedding=llm_embedding.to(self.device)):
+                    if i in self.silent_tokens:
+                        cur_silent_token_num += 1
+                        if cur_silent_token_num > max_silent_token_num:
+                            continue
+                    else:
+                        cur_silent_token_num = 0
                     self.tts_speech_token_dict[uuid].append(i)
             else:
                 for i in self.llm.inference(text=text.to(self.device),
@@ -117,6 +125,12 @@ class CosyVoiceModel:
                                             prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=llm_embedding.to(self.device),
                                             uuid=uuid):
+                    if i in self.silent_tokens:
+                        cur_silent_token_num += 1
+                        if cur_silent_token_num > max_silent_token_num:
+                            continue
+                    else:
+                        cur_silent_token_num = 0
                     self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
@@ -260,6 +274,7 @@ class CosyVoice2Model(CosyVoiceModel):
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        self.silent_tokens = []
 
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
@@ -401,6 +416,8 @@ class CosyVoice3Model(CosyVoice2Model):
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
+        # FSQ silent tokens
+        self.silent_tokens = [28, 29]
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
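
For reference, the logic this commit adds to llm_job can be read as a run-length cap on silent tokens: at most max_silent_token_num (5) consecutive silent tokens are passed through to tts_speech_token_dict, any excess is dropped, and the counter resets on the first non-silent token. Below is a minimal standalone sketch of that filter; the names cap_silent_runs and max_run are illustrative and not part of the repository.

    from typing import Iterable, Iterator, Set

    def cap_silent_runs(tokens: Iterable[int],
                        silent_tokens: Set[int],
                        max_run: int = 5) -> Iterator[int]:
        """Yield tokens, truncating every run of consecutive silent tokens to max_run.

        Mirrors the counter added to llm_job: the first max_run silent tokens
        of a run pass through, the rest are dropped, and any non-silent token
        resets the counter.
        """
        run = 0
        for tok in tokens:
            if tok in silent_tokens:
                run += 1
                if run > max_run:
                    continue  # drop the excess silent token
            else:
                run = 0  # a non-silent token ends the run
            yield tok

    # Example with the FSQ silent tokens {28, 29} hard-coded for
    # CosyVoice3Model, and max_run lowered to 2 to keep the demo short:
    stream = [7, 28, 28, 28, 28, 29, 3, 28]
    print(list(cap_silent_runs(stream, {28, 29}, max_run=2)))
    # -> [7, 28, 28, 3, 28]

Note that because the counter resets only on non-silent tokens, a run that mixes tokens 28 and 29 counts as a single silent run. CosyVoiceModel and CosyVoice2Model keep silent_tokens empty, so for them the added check is a no-op.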