@@ -301,18 +301,23 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
 
-    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
+    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
         speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
+        # NOTE: CosyVoice3 adds an optional instruct_token prefix
+        if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
+            instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
+            instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
         for i in range(len(text_token)):
             # bistream sequence
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
-                this_lm_target, this_lm_input = [], []
-                this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(sos_emb.squeeze(dim=0))
+                this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
+                if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
+                    this_lm_target += [IGNORE_ID] * instruct_token_len[i]
+                    this_lm_input.append(instruct_token_emb[i])
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
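The next hunk continues the same bistream branch; the interleaving loop between these two chunks is untouched by this PR. For review context, here is a minimal, self-contained sketch (toy function and values, not repo code) of the layout that loop builds: blocks of `mix_ratio[0]` text tokens alternating with blocks of `mix_ratio[1]` speech tokens. The real loop derives the chunk count as `ceil((text_token_len + 1) / mix_ratio[0])` and additionally handles the trailing partial chunk; `mix_ratio=(5, 15)` below is an assumed CosyVoice2-style setting.

```python
# Toy sketch of the bistream interleave, not repo code.
def interleave(text_ids, speech_ids, mix_ratio=(5, 15)):
    out = []
    n_chunks = -(-len(text_ids) // mix_ratio[0])  # ceil division
    for j in range(n_chunks):
        out += text_ids[j * mix_ratio[0]:(j + 1) * mix_ratio[0]]
        out += speech_ids[j * mix_ratio[1]:(j + 1) * mix_ratio[1]]
    return out

print(interleave(list("abcdefg"), list(range(21))))
# ['a'..'e'] + [0..14] + ['f', 'g'] + [15..20]
```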
@@ -333,8 +338,11 @@ class Qwen2LM(TransformerLM):
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
-                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
+                # fall back to a zero-length instruct prefix so callers without instruct tokens still work
+                this_instruct_len = instruct_token_len[i] if instruct_token_len is not None else 0
+                this_instruct_emb = [instruct_token_emb[i]] if instruct_token_emb is not None else []
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + this_instruct_len + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([sos_emb.squeeze(dim=0)] + this_instruct_emb + [text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
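As the unistream branch above shows, the target supervises only the speech tokens and the final eos: every prompt position (sos, the new instruct prefix, text, task_id) is masked with `IGNORE_ID`. A toy sketch of the resulting layout, with hypothetical values (`IGNORE_ID = -1` matches the repo convention; the lengths and eos id are made up):

```python
# Toy sketch of the unistream target layout after this patch, not repo code.
IGNORE_ID, eos_token = -1, 6562          # eos id is hypothetical
instruct_len, text_len = 2, 3            # made-up lengths
speech = [10, 11, 12, 13]                # made-up speech token ids
target = [IGNORE_ID] * (1 + instruct_len + text_len) + speech + [eos_token]
print(target)  # [-1, -1, -1, -1, -1, -1, 10, 11, 12, 13, 6562]
# The input embeddings line up position-for-position:
#   [sos][instruct x2][text x3][task_id][speech x4]
# so the loss fires only on the speech tokens and eos.
```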
@@ -681,6 +689,7 @@ class CosyVoice3LM(Qwen2LM):
 
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
+        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
 
         # 3. sos and task_id
         sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
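Note that the instruct tokens are looked up in the same Qwen2 text-embedding table as the text tokens, so this hunk adds no new parameters. A quick shape sketch (sizes assumed, roughly Qwen2-0.5B; not repo code):

```python
# Shape sketch with assumed sizes, not repo code: instruct tokens reuse the
# LLM's token embedding, so both lookups yield (batch, length, hidden).
import torch

vocab, hidden = 151936, 896                        # assumed Qwen2-0.5B sizes
embed_tokens = torch.nn.Embedding(vocab, hidden)   # stands in for llm.model.model.embed_tokens
text_token = torch.randint(0, vocab, (2, 7))       # (batch=2, text_len=7)
instruct_token = torch.randint(0, vocab, (2, 4))   # (batch=2, instruct_len=4)
print(embed_tokens(text_token).shape)              # torch.Size([2, 7, 896])
print(embed_tokens(instruct_token).shape)          # torch.Size([2, 4, 896])
```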
@@ -691,7 +700,7 @@ class CosyVoice3LM(Qwen2LM):
 
         # 3. prepare llm_input/target
         lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
-                                                                         speech_token, speech_token_emb, speech_token_len)
+                                                                         speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
         lm_target = lm_target.to(device)
 
         # 4. run lm forward
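One last note for reviewers: `prepare_lm_input_target` splits every padded batch, including the new instruct tensors, with `torch.nn.utils.rnn.unpad_sequence`, which is presumably why the patch moves the length tensors to `.cpu()` before the call. A quick demo of its behavior:

```python
# Demo of torch.nn.utils.rnn.unpad_sequence: it turns a padded batch back
# into a list of variable-length tensors.
import torch
from torch.nn.utils.rnn import unpad_sequence

padded = torch.tensor([[1, 2, 3, 0],
                       [4, 5, 0, 0]])      # padded batch, shape (B=2, T=4)
lengths = torch.tensor([3, 2])             # true lengths, kept on CPU
seqs = unpad_sequence(padded, lengths, batch_first=True)
print([s.tolist() for s in seqs])          # [[1, 2, 3], [4, 5]]
```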
|