lyuxiang.lx 4 weeks ago
Parent
Commit f97d50d559
1 changed file with 5 additions and 3 deletions

+ 5 - 3
cosyvoice/llm/llm.py

@@ -311,13 +311,15 @@ class Qwen2LM(TransformerLM):
         if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
             instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
             instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
+        else:
+            instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
+            instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
         for i in range(len(text_token)):
             # bistream sequence
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
-                if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
-                    this_lm_target += [IGNORE_ID] * instruct_token_len[i]
-                    this_lm_input.append(instruct_token_emb[i])
+                this_lm_target += [IGNORE_ID] * instruct_token_len[i]
+                this_lm_input.append(instruct_token_emb[i])
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
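
The change moves the instruct handling from a per-sample `if` inside the bistream branch to a batch-level fallback: when no instruct conditioning is supplied, every sample gets a zero-length `instruct_token` and an empty `(0, 896)` `instruct_token_emb` (896 being the embedding width hard-coded in the diff), so the loop can append the instruct slot unconditionally. A minimal sketch of why the empty placeholders are safe, with made-up shapes; only the two fallback lines are taken from the diff:

import torch

# Hypothetical batch of two samples; 896 matches the width used in the diff.
text_token = [torch.tensor([3, 14, 15]), torch.tensor([9, 26])]
text_token_emb = [torch.randn(3, 896), torch.randn(2, 896)]

# Fallback from the commit: zero-length tensors with matching dtype/device.
instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)

# Concatenating the (0, 896) placeholder is a no-op, so an instruct-free
# sample yields the same LM input it did before the change.
lm_input = torch.concat([instruct_token_emb[0], text_token_emb[0]], dim=0)
assert lm_input.shape == (3, 896)

Note that the hunk shows no matching fallback for `instruct_token_len`, which the loop now indexes unconditionally (`[IGNORE_ID] * instruct_token_len[i]`); presumably callers always pass it, or a zero default such as `instruct_token_len = torch.zeros(len(text_token), dtype=torch.int32)` (an assumption, not part of this commit) is set elsewhere.

For context, the surrounding loop interleaves text and speech tokens at a fixed ratio. A sketch of the chunk arithmetic from the unchanged lines, assuming `mix_ratio = [5, 15]` as in the stock CosyVoice2 config (an assumption here):

import torch

mix_ratio = [5, 15]  # assumed: 5 text tokens interleaved with 15 speech tokens
text_len, speech_len = torch.tensor(8), torch.tensor(40)

# A sample takes the bistream path only if it has more than
# mix_ratio[1] / mix_ratio[0] speech tokens per text token.
assert speech_len / text_len > mix_ratio[1] / mix_ratio[0]

# ceil((8 + 1) / 5) == 2 chunks, each up to 5 text then 15 speech tokens.
for j in range(((text_len + 1) / mix_ratio[0]).ceil().int().item()):
    print(f"chunk {j}: text[{j * mix_ratio[0]}:{(j + 1) * mix_ratio[0]}], "
          f"speech[{j * mix_ratio[1]}:{(j + 1) * mix_ratio[1]}]")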