Browse Source

simply code

lyuxiang.lx 2 weeks ago
parent
commit
0e624898a1
3 changed files with 62 additions and 118 deletions
  1. 0 2
      cosyvoice/cli/cosyvoice.py
  2. 3 3
      cosyvoice/cli/model.py
  3. 59 113
      cosyvoice/llm/llm.py

+ 0 - 2
cosyvoice/cli/cosyvoice.py

@@ -89,8 +89,6 @@ class CosyVoice:
                 start_time = time.time()
 
     def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
-        if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
-            logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):

+ 3 - 3
cosyvoice/cli/model.py

@@ -117,7 +117,7 @@ class CosyVoiceModel:
                                                      prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                      prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                      embedding=llm_embedding.to(self.device),
-                                                     uuid=uuid)  
+                                                     uuid=uuid)
             for i in token_generator:
                 if i in self.silent_tokens:
                     cur_silent_token_num += 1
@@ -256,7 +256,7 @@ class CosyVoice2Model(CosyVoiceModel):
         self.fp16 = fp16
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
-        # NOTE increase token_hop_len incrementally to avoid duplicate inference 
+        # NOTE increase token_hop_len incrementally to avoid duplicate inference
         self.token_max_hop_len = 4 * self.token_hop_len
         self.stream_scale_factor = 2
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
@@ -408,7 +408,7 @@ class CosyVoice3Model(CosyVoice2Model):
         self.fp16 = fp16
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
-        # NOTE increase token_hop_len incrementally to avoid duplicate inference 
+        # NOTE increase token_hop_len incrementally to avoid duplicate inference
         self.token_max_hop_len = 4 * self.token_hop_len
         self.stream_scale_factor = 2
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'

+ 59 - 113
cosyvoice/llm/llm.py

@@ -154,14 +154,9 @@ class TransformerLM(torch.nn.Module):
             sampling: int,
             ignore_eos: bool = True,
     ):
-        num_trials, max_trials = 0, 100
-        while True:
-            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (top_ids < self.speech_token_size):
-                break
-            num_trials += 1
-            if num_trials > max_trials:
-                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
+        if ignore_eos is True:
+            weighted_scores[self.speech_token_size] = -float('inf')
+        top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
         return top_ids
 
     @torch.inference_mode()
@@ -365,34 +360,48 @@ class Qwen2LM(TransformerLM):
             audio: (B, T, N) or (B, T)
             audio_lengths: (B,)
         """
+        # 1. encode text_token
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 2. encode speech_token
         if 'speech_token' not in batch:
             speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
         else:
             speech_token = batch['speech_token'].to(device)
             speech_token_len = batch['speech_token_len'].to(device)
-
-        # 1. encode text_token
-        text_token_emb = self.llm.model.model.embed_tokens(text_token)
-
-        # 3. sos and task_id
-        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
-
-        # 2. encode speech_token
         speech_token_emb = self.speech_embedding(speech_token)
 
-        # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
-                                                                         speech_token, speech_token_emb, speech_token_len)
+        # 3. sos and task_id
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+        elif self.__class__.__name__ == 'Qwen2LM':
+            sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        else:
+            raise ValueError
+
+        # 4. prepare llm_input/target
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            instruct_token = batch['instruct_token'].to(device)
+            instruct_token_len = batch['instruct_token_len'].to(device)
+            instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
+            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
+                                                                             speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
+        elif self.__class__.__name__ == 'Qwen2LM':
+            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
+                                                                             speech_token, speech_token_emb, speech_token_len)
+        else:
+            raise ValueError
         lm_target = lm_target.to(device)
 
         # 4. run lm forward
         lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
         logits = self.llm_decoder(lm_output)
         loss = self.criterion_ce(logits, lm_target.to(device))
-        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+        acc = th_accuracy(logits.view(-1, self.llm_decoder.out_features), lm_target, ignore_label=IGNORE_ID)
         return {'loss': loss, 'acc': acc}
 
     def forward_dpo(
@@ -464,16 +473,25 @@ class Qwen2LM(TransformerLM):
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
-        text = self.llm.model.model.embed_tokens(text)
+        text_emb = self.llm.model.model.embed_tokens(text)
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            # NOTE temporary hardcode, 151646 is <|endofprompt|> token
+            assert 151646 in text, '<|endofprompt|> not detected in CosyVoice3 text or prompt_text, check your input!'
 
         # 3. concat llm_input
-        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+        elif self.__class__.__name__ == 'Qwen2LM':
+            sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        else:
+            raise ValueError
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text_emb.dtype).to(device)
+        lm_input = torch.concat([sos_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -546,8 +564,14 @@ class Qwen2LM(TransformerLM):
 
         device = prompt_text.device
         # 1. prepare input
-        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+        elif self.__class__.__name__ == 'Qwen2LM':
+            sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        else:
+            raise ValueError
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
@@ -558,6 +582,12 @@ class Qwen2LM(TransformerLM):
         out_tokens = []
         cache = None
         # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            # NOTE temporary hardcode, 151646 is <|endofprompt|> token
+            assert 151646 in prompt_text, '<|endofprompt|> not detected in CosyVoice3 prompt_text, check your input!'
+            eop_index = prompt_text.flatten().tolist().index(151646)
+            lm_input = torch.concat([lm_input, self.llm.model.model.embed_tokens(prompt_text[:, :eop_index + 1])], dim=1)
+            prompt_text = prompt_text[:, eop_index + 1:]
         text_cache = self.llm.model.model.embed_tokens(prompt_text)
         next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
         for this_text in text:
@@ -673,88 +703,4 @@ class CosyVoice3LM(Qwen2LM):
         self.stop_token_ids = [speech_token_size + i for i in range(200)]
         self.vllm_output_queue = {}
         if online_feature is True:
-            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
-
-    def forward(
-            self,
-            batch: dict,
-            device: torch.device,
-    ) -> Dict[str, Optional[torch.Tensor]]:
-        """
-        Args:
-            text: (B, L, D)
-            text_lengths: (B,)
-            audio: (B, T, N) or (B, T)
-            audio_lengths: (B,)
-        """
-        text_token = batch['text_token'].to(device)
-        text_token_len = batch['text_token_len'].to(device)
-        if 'speech_token' not in batch:
-            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
-        else:
-            speech_token = batch['speech_token'].to(device)
-            speech_token_len = batch['speech_token_len'].to(device)
-
-        # NOTE should append instruct_token to sequence, not implemented yet
-        instruct_token = batch['instruct_token'].to(device)
-        instruct_token_len = batch['instruct_token_len'].to(device)
-
-        # 1. encode text_token
-        text_token_emb = self.llm.model.model.embed_tokens(text_token)
-        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
-
-        # 3. sos and task_id
-        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
-
-        # 2. encode speech_token
-        speech_token_emb = self.speech_embedding(speech_token)
-
-        # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
-                                                                         speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
-        lm_target = lm_target.to(device)
-
-        # 4. run lm forward
-        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
-        logits = self.llm_decoder(lm_output)
-        loss = self.criterion_ce(logits, lm_target.to(device))
-        acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
-        return {'loss': loss, 'acc': acc}
-
-    @torch.inference_mode()
-    def inference(
-            self,
-            text: torch.Tensor,
-            text_len: torch.Tensor,
-            prompt_text: torch.Tensor,
-            prompt_text_len: torch.Tensor,
-            prompt_speech_token: torch.Tensor,
-            prompt_speech_token_len: torch.Tensor,
-            embedding: torch.Tensor,
-            sampling: int = 25,
-            max_token_text_ratio: float = 20,
-            min_token_text_ratio: float = 2,
-            uuid: str = '',
-    ) -> Generator[torch.Tensor, None, None]:
-        device = text.device
-        text = torch.concat([prompt_text, text], dim=1)
-        text_len += prompt_text_len
-        text = self.llm.model.model.embed_tokens(text)
-
-        # 3. concat llm_input
-        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
-        if prompt_speech_token_len != 0:
-            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
-        else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
-
-        # 4. cal min/max_length
-        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
-        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
-
-        # 5. step by step decode
-        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
-            yield token
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))