lyuxiang.lx hai 8 meses
pai
achega
dc96e4c984
Modificáronse 3 ficheiros con 45 adicións e 43 borrados
  1. 1 1
      cosyvoice/cli/cosyvoice.py
  2. 40 38
      cosyvoice/llm/llm.py
  3. 4 4
      cosyvoice/utils/class_utils.py

+ 1 - 1
cosyvoice/cli/cosyvoice.py

@@ -207,7 +207,7 @@ class CosyVoice3(CosyVoice):
             raise ValueError('{} not found!'.format(hyper_yaml_path))
         with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
-        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
+        assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),

+ 40 - 38
cosyvoice/llm/llm.py

@@ -56,8 +56,9 @@ class TransformerLM(torch.nn.Module):
         )
 
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
+        self.eos_token = self.speech_token_size
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
@@ -85,10 +86,10 @@ class TransformerLM(torch.nn.Module):
         encoder_out = self.text_encoder_affine_layer(encoder_out)
         return encoder_out, encoder_out_lens
 
-    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
+    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                     for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
@@ -127,14 +128,14 @@ class TransformerLM(torch.nn.Module):
         embedding = embedding.unsqueeze(1)
 
         # 3. eos and task_id
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
 
         # 4. encode speech_token
         speech_token = self.speech_embedding(speech_token)
 
         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                          task_id_emb, speech_token, speech_token_len)
 
         # 6. run lm forward
@@ -193,13 +194,13 @@ class TransformerLM(torch.nn.Module):
             embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
 
         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -215,11 +216,8 @@ class TransformerLM(torch.nn.Module):
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                  device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            # force continue decode first token
-            if i == 0:
-                logp[:, self.speech_token_size] = -float('inf')
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
+            if top_ids == self.eos_token:
                 break
             # in stream mode, yield token one by one
             yield top_ids
@@ -276,9 +274,10 @@ class Qwen2LM(TransformerLM):
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
-        self.fill_token = 2
+        self.eos_token = speech_token_size
+        self.fill_token = speech_token_size + 2
 
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
@@ -312,7 +311,7 @@ class Qwen2LM(TransformerLM):
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [], []
                 this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
+                this_lm_input.append(self.llm_embedding.weight[self.sos].reshape(1, -1))
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -320,21 +319,21 @@ class Qwen2LM(TransformerLM):
                         assert len(this_speech_token) == self.mix_ratio[1]
                         this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                         this_lm_target += this_speech_token
-                        this_lm_target.append(self.speech_token_size + 2)
+                        this_lm_target.append(self.fill_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                     else:
                         this_lm_target += [-1] * len(this_text_token)
                         this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
-                        this_lm_target.append(self.speech_token_size)
+                        this_lm_target.append(self.eos_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                         this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos].reshape(1, -1), text_token_emb[i],
                                               self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
@@ -445,13 +444,13 @@ class Qwen2LM(TransformerLM):
         text = self.llm.model.model.embed_tokens(text)
 
         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -501,10 +500,8 @@ class Qwen2LM(TransformerLM):
                                                           cache=cache)
                 logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                 top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-                if top_ids == self.speech_token_size:
+                if top_ids in self.stop_token_ids:
                     break
-                if top_ids > self.speech_token_size:
-                    continue
                 # in stream mode, yield token one by one
                 yield top_ids
                 out_tokens.append(top_ids)
@@ -526,13 +523,13 @@ class Qwen2LM(TransformerLM):
 
         device = prompt_text.device
         # 1. prepare input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb], dim=1)
+        lm_input = torch.concat([sos_emb], dim=1)
 
         # 2. iterate text
         out_tokens = []
@@ -554,12 +551,12 @@ class Qwen2LM(TransformerLM):
                     break
             # no prompt_speech_token_emb remain, can decode some speech token
             if prompt_speech_token_emb.size(1) == 0:
-                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                     logging.info('get fill token, need to append more text token')
                     if text_cache.size(1) >= self.mix_ratio[0]:
                         lm_input_text = text_cache[:, :self.mix_ratio[0]]
                         logging.info('append {} text token'.format(lm_input_text.size(1)))
-                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                             lm_input = lm_input_text
                         else:
                             lm_input = torch.concat([lm_input, lm_input_text], dim=1)
@@ -574,16 +571,16 @@ class Qwen2LM(TransformerLM):
                                                               cache=cache)
                     logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                     if next_fill_index != -1 and len(out_tokens) == next_fill_index:
-                        top_ids = self.speech_token_size + 2
+                        top_ids = self.fill_token
                         next_fill_index += (self.mix_ratio[1] + 1)
                     else:
                         top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
-                    if top_ids == self.speech_token_size + 2:
+                    if top_ids == self.fill_token:
                         next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                         logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                     out_tokens.append(top_ids)
                     if top_ids >= self.speech_token_size:
-                        if top_ids == self.speech_token_size + 2:
+                        if top_ids == self.fill_token:
                             break
                         else:
                             raise ValueError('should not get token {}'.format(top_ids))
@@ -602,7 +599,7 @@ class Qwen2LM(TransformerLM):
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
             out_tokens.append(top_ids)
             if top_ids >= self.speech_token_size:
-                if top_ids == self.speech_token_size:
+                if top_ids == self.eos_token:
                     break
                 else:
                     raise ValueError('should not get token {}'.format(top_ids))
@@ -628,10 +625,10 @@ class CosyVoice3LM(Qwen2LM):
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
         # 2. build speech token language model related modules
-        self.sos = 0
-        self.eos = 1
-        self.task_id = 2
-        self.fill_token = 3
+        self.sos = speech_token_size + 0
+        self.eos_token = speech_token_size + 1
+        self.task_id = speech_token_size + 2
+        self.fill_token = speech_token_size + 3
 
         self.llm = llm
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
@@ -649,6 +646,11 @@ class CosyVoice3LM(Qwen2LM):
         self.sampling = sampling
         self.mix_ratio = mix_ratio
 
+        # 5. vllm related
+        self.stop_token_ids = [speech_token_size + i for i in range(4)]
+        self.vllm_output_queue = {}
+
+
     @torch.inference_mode()
     def inference(
             self,
@@ -670,13 +672,13 @@ class CosyVoice3LM(Qwen2LM):
         text = self.llm.model.model.embed_tokens(text)
 
         # 3. concat llm_input
-        sos_eos_emb = self.speech_embedding.weight[self.speech_token_size + self.sos].reshape(1, 1, -1)
-        task_id_emb = self.speech_embedding.weight[self.speech_token_size + self.task_id].reshape(1, 1, -1)
+        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)

+ 4 - 4
cosyvoice/utils/class_utils.py

@@ -32,10 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
-from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
 from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
 from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
 
 
 COSYVOICE_ACTIVATION_CLASSES = {
@@ -80,6 +80,6 @@ def get_model_type(configs):
         return CosyVoiceModel
     if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
         return CosyVoice2Model
-    if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
-        return CosyVoice2Model
+    if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
+        return CosyVoice3Model
     raise TypeError('No valid model type found!')