1
0
lyuxiang.lx преди 1 година
родител
ревизия
cde3cec6fa
променени са 1 файла, в които са добавени 125 реда и са изтрити 0 реда
  1. 125 0
      cosyvoice/llm/llm.py

+ 125 - 0
cosyvoice/llm/llm.py

@@ -15,6 +15,7 @@ from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
+from transformers import Qwen2ForCausalLM
 from torch.nn.utils.rnn import pad_sequence, unpad_sequence
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
@@ -213,3 +214,127 @@ class TransformerLM(torch.nn.Module):
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+
+class Qwen2Encoder(torch.nn.Module):
+    def __init__(self, pretrain_path):
+        super().__init__()
+        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
+
+    def forward_one_step(self, xs, masks, cache=None):
+        input_masks = masks[:, -1, :]
+        outs = self.model(
+            inputs_embeds=xs,
+            attention_mask=input_masks,
+            output_hidden_states=True,
+            return_dict=True,
+            use_cache=True,
+            past_key_values=cache,
+        )
+        xs = outs.hidden_states[-1]
+        new_cache = outs.past_key_values
+        return xs, new_cache
+
+
+class Qwen2LM(torch.nn.Module):
+    def __init__(
+            self,
+            llm_input_size: int,
+            llm_output_size: int,
+            speech_token_size: int,
+            llm: torch.nn.Module,
+            sampling: Callable,
+            length_normalized_loss: bool = True,
+            lsm_weight: float = 0.0,
+    ):
+        super().__init__()
+        self.llm_input_size = llm_input_size
+        self.llm_output_size = llm_output_size
+        self.speech_token_size = speech_token_size
+
+        # 2. build speech token language model related modules
+        self.sos_eos = 0
+        self.task_id = 1
+        self.fill_token = 2
+
+        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
+        self.llm = llm
+        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
+        self.criterion_ce = LabelSmoothingLoss(
+            size=speech_token_size + 3,
+            padding_idx=IGNORE_ID,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
+
+        # 4. sampling method
+        self.sampling = sampling
+
+    def sampling_ids(
+            self,
+            weighted_scores: torch.Tensor,
+            decoded_tokens: List,
+            sampling: int,
+            ignore_eos: bool = True,
+    ):
+        while True:
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
+            if (not ignore_eos) or (self.speech_token_size not in top_ids):
+                break
+        return top_ids
+
+    @torch.inference_mode()
+    def inference(
+            self,
+            text: torch.Tensor,
+            text_len: torch.Tensor,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+        device = text.device
+        text = torch.concat([prompt_text, text], dim=1)
+        text_len += prompt_text_len
+        text = self.llm.model.model.embed_tokens(text)
+
+        # 2. encode embedding
+        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+
+        # 3. concat llm_input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+
+        # 4. cal min/max_length
+        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+        # 5. step by step decode
+        out_tokens = []
+        cache = None
+        for i in range(max_len):
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+            if top_ids == self.speech_token_size:
+                break
+            if top_ids > self.speech_token_size:
+                continue
+            # in stream mode, yield token one by one
+            yield top_ids
+            out_tokens.append(top_ids)
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)