|
|
@@ -15,6 +15,7 @@ from typing import Dict, Optional, Callable, List, Generator
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
import torch.nn.functional as F
|
|
|
+from transformers import Qwen2ForCausalLM
|
|
|
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
|
|
from cosyvoice.utils.common import IGNORE_ID
|
|
|
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
|
|
@@ -213,3 +214,127 @@ class TransformerLM(torch.nn.Module):
|
|
|
out_tokens.append(top_ids)
|
|
|
offset += lm_input.size(1)
|
|
|
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2Encoder(torch.nn.Module):
|
|
|
+ def __init__(self, pretrain_path):
|
|
|
+ super().__init__()
|
|
|
+ self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
|
|
|
+
|
|
|
+ def forward_one_step(self, xs, masks, cache=None):
|
|
|
+ input_masks = masks[:, -1, :]
|
|
|
+ outs = self.model(
|
|
|
+ inputs_embeds=xs,
|
|
|
+ attention_mask=input_masks,
|
|
|
+ output_hidden_states=True,
|
|
|
+ return_dict=True,
|
|
|
+ use_cache=True,
|
|
|
+ past_key_values=cache,
|
|
|
+ )
|
|
|
+ xs = outs.hidden_states[-1]
|
|
|
+ new_cache = outs.past_key_values
|
|
|
+ return xs, new_cache
|
|
|
+
|
|
|
+
|
|
|
+class Qwen2LM(torch.nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ llm_input_size: int,
|
|
|
+ llm_output_size: int,
|
|
|
+ speech_token_size: int,
|
|
|
+ llm: torch.nn.Module,
|
|
|
+ sampling: Callable,
|
|
|
+ length_normalized_loss: bool = True,
|
|
|
+ lsm_weight: float = 0.0,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ self.llm_input_size = llm_input_size
|
|
|
+ self.llm_output_size = llm_output_size
|
|
|
+ self.speech_token_size = speech_token_size
|
|
|
+
|
|
|
+ # 2. build speech token language model related modules
|
|
|
+ self.sos_eos = 0
|
|
|
+ self.task_id = 1
|
|
|
+ self.fill_token = 2
|
|
|
+
|
|
|
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
|
|
+ self.llm = llm
|
|
|
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
|
|
|
+ self.criterion_ce = LabelSmoothingLoss(
|
|
|
+ size=speech_token_size + 3,
|
|
|
+ padding_idx=IGNORE_ID,
|
|
|
+ smoothing=lsm_weight,
|
|
|
+ normalize_length=length_normalized_loss,
|
|
|
+ )
|
|
|
+
|
|
|
+ # 3. [Optional] build speech token related modules
|
|
|
+ self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
|
|
|
+
|
|
|
+ # 4. sampling method
|
|
|
+ self.sampling = sampling
|
|
|
+
|
|
|
+ def sampling_ids(
|
|
|
+ self,
|
|
|
+ weighted_scores: torch.Tensor,
|
|
|
+ decoded_tokens: List,
|
|
|
+ sampling: int,
|
|
|
+ ignore_eos: bool = True,
|
|
|
+ ):
|
|
|
+ while True:
|
|
|
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
|
|
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
|
|
+ break
|
|
|
+ return top_ids
|
|
|
+
|
|
|
+ @torch.inference_mode()
|
|
|
+ def inference(
|
|
|
+ self,
|
|
|
+ text: torch.Tensor,
|
|
|
+ text_len: torch.Tensor,
|
|
|
+ prompt_text: torch.Tensor,
|
|
|
+ prompt_text_len: torch.Tensor,
|
|
|
+ prompt_speech_token: torch.Tensor,
|
|
|
+ prompt_speech_token_len: torch.Tensor,
|
|
|
+ embedding: torch.Tensor,
|
|
|
+ sampling: int = 25,
|
|
|
+ max_token_text_ratio: float = 20,
|
|
|
+ min_token_text_ratio: float = 2,
|
|
|
+ ) -> Generator[torch.Tensor, None, None]:
|
|
|
+ device = text.device
|
|
|
+ text = torch.concat([prompt_text, text], dim=1)
|
|
|
+ text_len += prompt_text_len
|
|
|
+ text = self.llm.model.model.embed_tokens(text)
|
|
|
+
|
|
|
+ # 2. encode embedding
|
|
|
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
|
|
+
|
|
|
+ # 3. concat llm_input
|
|
|
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
|
|
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
|
|
+ if prompt_speech_token_len != 0:
|
|
|
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
|
|
+ else:
|
|
|
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
|
|
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
|
|
+
|
|
|
+ # 4. cal min/max_length
|
|
|
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
|
|
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
|
|
+
|
|
|
+ # 5. step by step decode
|
|
|
+ out_tokens = []
|
|
|
+ cache = None
|
|
|
+ for i in range(max_len):
|
|
|
+ y_pred, cache = self.llm.forward_one_step(lm_input,
|
|
|
+ masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
|
|
+ cache=cache)
|
|
|
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
|
|
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
|
|
|
+ if top_ids == self.speech_token_size:
|
|
|
+ break
|
|
|
+ if top_ids > self.speech_token_size:
|
|
|
+ continue
|
|
|
+ # in stream mode, yield token one by one
|
|
|
+ yield top_ids
|
|
|
+ out_tokens.append(top_ids)
|
|
|
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|