@@ -154,14 +154,9 @@ class TransformerLM(torch.nn.Module):
             sampling: int,
             ignore_eos: bool = True,
     ):
-        num_trials, max_trials = 0, 100
-        while True:
-            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (top_ids < self.speech_token_size):
-                break
-            num_trials += 1
-            if num_trials > max_trials:
-                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
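+        # NOTE masking the EOS logit up front lets a single sampling call suffice,
+        # replacing the old resample-up-to-max_trials loop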
+        if ignore_eos is True:
+            weighted_scores[self.speech_token_size] = -float('inf')
+        top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
         return top_ids
 
     @torch.inference_mode()
@@ -365,34 +360,48 @@ class Qwen2LM(TransformerLM):
             audio: (B, T, N) or (B, T)
             audio_lengths: (B,)
         """
+        # 1. encode text_token
         text_token = batch['text_token'].to(device)
         text_token_len = batch['text_token_len'].to(device)
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 2. encode speech_token
         if 'speech_token' not in batch:
             speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
         else:
             speech_token = batch['speech_token'].to(device)
             speech_token_len = batch['speech_token_len'].to(device)
-
-        # 1. encode text_token
-        text_token_emb = self.llm.model.model.embed_tokens(text_token)
-
-        # 3. sos and task_id
-        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
-
-        # 2. encode speech_token
         speech_token_emb = self.speech_embedding(speech_token)
 
-        # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
-                                                                         speech_token, speech_token_emb, speech_token_len)
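+        # NOTE sos/task_id ids index into speech_embedding for CosyVoice3LM but into the
+        # separate llm_embedding table for Qwen2LM, hence the dispatch on the class name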
+        # 3. sos and task_id
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+        elif self.__class__.__name__ == 'Qwen2LM':
+            sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        else:
+            raise ValueError('unsupported lm class {}'.format(self.__class__.__name__))
+
+        # 4. prepare llm_input/target
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            instruct_token = batch['instruct_token'].to(device)
+            instruct_token_len = batch['instruct_token_len'].to(device)
+            instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
+            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
+                                                                             speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
+        elif self.__class__.__name__ == 'Qwen2LM':
+            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
+                                                                             speech_token, speech_token_emb, speech_token_len)
+        else:
+            raise ValueError('unsupported lm class {}'.format(self.__class__.__name__))
         lm_target = lm_target.to(device)
 
-        # 4. run lm forward
+        # 5. run lm forward
         lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
         logits = self.llm_decoder(lm_output)
         loss = self.criterion_ce(logits, lm_target.to(device))
-        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
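+        # NOTE llm_decoder.out_features equals speech_token_size + 3 for Qwen2LM and
+        # speech_token_size + 200 for CosyVoice3LM, so one view() now covers both heads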
+        acc = th_accuracy(logits.view(-1, self.llm_decoder.out_features), lm_target, ignore_label=IGNORE_ID)
         return {'loss': loss, 'acc': acc}
 
     def forward_dpo(
@@ -464,16 +473,25 @@ class Qwen2LM(TransformerLM):
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
-        text = self.llm.model.model.embed_tokens(text)
+        text_emb = self.llm.model.model.embed_tokens(text)
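+        # NOTE the embeddings now go to text_emb so the raw token ids in text stay
+        # available for the <|endofprompt|> check below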
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            # NOTE temporary hardcode, 151646 is <|endofprompt|> token
+            assert 151646 in text, '<|endofprompt|> not detected in CosyVoice3 text or prompt_text, check your input!'
 
         # 3. concat llm_input
-        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+        elif self.__class__.__name__ == 'Qwen2LM':
+            sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        else:
+            raise ValueError('unsupported lm class {}'.format(self.__class__.__name__))
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text_emb.dtype).to(device)
+        lm_input = torch.concat([sos_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -546,8 +564,14 @@ class Qwen2LM(TransformerLM):
 
         device = prompt_text.device
         # 1. prepare input
-        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+        elif self.__class__.__name__ == 'Qwen2LM':
+            sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        else:
+            raise ValueError('unsupported lm class {}'.format(self.__class__.__name__))
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
@@ -558,6 +582,12 @@ class Qwen2LM(TransformerLM):
         out_tokens = []
         cache = None
         # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+        if self.__class__.__name__ == 'CosyVoice3LM':
+            # NOTE temporary hardcode, 151646 is <|endofprompt|> token
+            assert 151646 in prompt_text, '<|endofprompt|> not detected in CosyVoice3 prompt_text, check your input!'
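+            # NOTE commit everything up to and including <|endofprompt|> to lm_input right
+            # away; only the text after it is streamed through text_cache below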
+            eop_index = prompt_text.flatten().tolist().index(151646)
+            lm_input = torch.concat([lm_input, self.llm.model.model.embed_tokens(prompt_text[:, :eop_index + 1])], dim=1)
+            prompt_text = prompt_text[:, eop_index + 1:]
         text_cache = self.llm.model.model.embed_tokens(prompt_text)
         next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
         for this_text in text:
@@ -673,88 +703,4 @@ class CosyVoice3LM(Qwen2LM):
         self.stop_token_ids = [speech_token_size + i for i in range(200)]
         self.vllm_output_queue = {}
         if online_feature is True:
-            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
-
-    def forward(
-            self,
-            batch: dict,
-            device: torch.device,
-    ) -> Dict[str, Optional[torch.Tensor]]:
-        """
-        Args:
-            text: (B, L, D)
-            text_lengths: (B,)
-            audio: (B, T, N) or (B, T)
-            audio_lengths: (B,)
-        """
-        text_token = batch['text_token'].to(device)
-        text_token_len = batch['text_token_len'].to(device)
-        if 'speech_token' not in batch:
-            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
-        else:
-            speech_token = batch['speech_token'].to(device)
-            speech_token_len = batch['speech_token_len'].to(device)
-
-        # NOTE should append instruct_token to sequence, not implemented yet
-        instruct_token = batch['instruct_token'].to(device)
-        instruct_token_len = batch['instruct_token_len'].to(device)
-
-        # 1. encode text_token
-        text_token_emb = self.llm.model.model.embed_tokens(text_token)
-        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
-
-        # 3. sos and task_id
-        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
-
-        # 2. encode speech_token
-        speech_token_emb = self.speech_embedding(speech_token)
-
-        # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
-                                                                         speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
-        lm_target = lm_target.to(device)
-
-        # 4. run lm forward
-        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
-        logits = self.llm_decoder(lm_output)
-        loss = self.criterion_ce(logits, lm_target.to(device))
-        acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
-        return {'loss': loss, 'acc': acc}
-
-    @torch.inference_mode()
-    def inference(
-            self,
-            text: torch.Tensor,
-            text_len: torch.Tensor,
-            prompt_text: torch.Tensor,
-            prompt_text_len: torch.Tensor,
-            prompt_speech_token: torch.Tensor,
-            prompt_speech_token_len: torch.Tensor,
-            embedding: torch.Tensor,
-            sampling: int = 25,
-            max_token_text_ratio: float = 20,
-            min_token_text_ratio: float = 2,
-            uuid: str = '',
-    ) -> Generator[torch.Tensor, None, None]:
-        device = text.device
-        text = torch.concat([prompt_text, text], dim=1)
-        text_len += prompt_text_len
-        text = self.llm.model.model.embed_tokens(text)
-
-        # 3. concat llm_input
-        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
-        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
-        if prompt_speech_token_len != 0:
-            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
-        else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
-
-        # 4. cal min/max_length
-        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
-        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
-
-        # 5. step by step decode
-        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
-            yield token
+            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
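+        # NOTE the forward/inference overrides removed above are now inherited from
+        # Qwen2LM, which dispatches on the subclass name instead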