lyuxiang.lx 8 months ago
parent
commit
f76f5abcc1
1 changed files with 53 additions and 9 deletions
  1. 53 9
      cosyvoice/llm/llm.py

+ 53 - 9
cosyvoice/llm/llm.py

@@ -127,7 +127,7 @@ class TransformerLM(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)
         embedding = embedding.unsqueeze(1)
 
-        # 3. eos and task_id
+        # 3. sos and task_id
         sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
 
@@ -300,7 +300,7 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
 
-    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
+    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
@@ -311,7 +311,7 @@ class Qwen2LM(TransformerLM):
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [], []
                 this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(self.llm_embedding.weight[self.sos].reshape(1, -1))
+                this_lm_input.append(sos_emb.squeeze(dim=0))
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -327,14 +327,13 @@ class Qwen2LM(TransformerLM):
                         this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                         this_lm_target.append(self.eos_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
-                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
+                        this_lm_input.append(task_id_emb.squeeze(dim=0))
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
                 this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos].reshape(1, -1), text_token_emb[i],
-                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
+                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
@@ -362,11 +361,15 @@ class Qwen2LM(TransformerLM):
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
 
+        # 3. sos and task_id
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
         # 2. encode speech_token
         speech_token_emb = self.speech_embedding(speech_token)
 
         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len)
         lm_target = lm_target.to(device)
 
         # 4. run lm forward
@@ -391,6 +394,10 @@ class Qwen2LM(TransformerLM):
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
 
+        # 3. sos and task_id
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
         # 2. encode speech_token
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
@@ -400,8 +407,8 @@ class Qwen2LM(TransformerLM):
         speech_token_combined_emb = self.speech_embedding(speech_token_combined)
 
         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
-                                                                         speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
+                                                                         task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
         lm_target = lm_target.to(device)
 
         # 4. run lm forward
@@ -650,6 +657,43 @@ class CosyVoice3LM(Qwen2LM):
         self.stop_token_ids = [speech_token_size + i for i in range(4)]
         self.vllm_output_queue = {}
 
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Args:
+            text: (B, L, D)
+            text_lengths: (B,)
+            audio: (B, T, N) or (B, T)
+            audio_lengths: (B,)
+        """
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+
+        # 1. encode text_token
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 3. sos and task_id
+        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+
+        # 2. encode speech_token
+        speech_token_emb = self.speech_embedding(speech_token)
+
+        # 3. prepare llm_input/target
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len)
+        lm_target = lm_target.to(device)
+
+        # 4. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        loss = self.criterion_ce(logits, lm_target.to(device))
+        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+        return {'loss': loss, 'acc': acc}
 
     @torch.inference_mode()
     def inference(