lyuxiang.lx 5 meses atrás
pai
commit
70991d7327
2 arquivos alterados com 128 adições e 4 exclusões
  1. 76 0
      cosyvoice/llm/llm.py
  2. 52 4
      cosyvoice/tokenizer/tokenizer.py

+ 76 - 0
cosyvoice/llm/llm.py

@@ -609,3 +609,79 @@ class Qwen2LM(TransformerLM):
             # in stream mode, yield token one by one
             yield top_ids
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+
+class CosyVoice3LM(Qwen2LM):
+    def __init__(
+            self,
+            llm_input_size: int,
+            llm_output_size: int,
+            speech_token_size: int,
+            llm: torch.nn.Module,
+            sampling: Callable,
+            length_normalized_loss: bool = True,
+            lsm_weight: float = 0.0,
+            mix_ratio: List[int] = [5, 15],
+    ):
+        torch.nn.Module.__init__(self)
+        self.llm_input_size = llm_input_size
+        self.llm_output_size = llm_output_size
+        self.speech_token_size = speech_token_size
+        # 2. build speech token language model related modules
+        self.sos = 0
+        self.eos = 1
+        self.task_id = 2
+        self.fill_token = 3
+
+        self.llm = llm
+        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
+        self.criterion_ce = LabelSmoothingLoss(
+            size=speech_token_size + 200,
+            padding_idx=IGNORE_ID,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
+
+        # 4. sampling method
+        self.sampling = sampling
+        self.mix_ratio = mix_ratio
+
+    @torch.inference_mode()
+    def inference(
+            self,
+            text: torch.Tensor,
+            text_len: torch.Tensor,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+            uuid: str = '',
+    ) -> Generator[torch.Tensor, None, None]:
+        device = text.device
+        text = torch.concat([prompt_text, text], dim=1)
+        text_len += prompt_text_len
+        text = self.llm.model.model.embed_tokens(text)
+
+        # 3. concat llm_input
+        sos_eos_emb = self.speech_embedding.weight[self.speech_token_size + self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.speech_token_size + self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+
+        # 4. cal min/max_length
+        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+        # 5. step by step decode
+        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
+            yield token

+ 52 - 4
cosyvoice/tokenizer/tokenizer.py

@@ -238,7 +238,7 @@ def get_tokenizer(
     )
 
 
-class QwenTokenizer():
+class CosyVoice2Tokenizer():
     def __init__(self, token_path, skip_special_tokens=True):
         super().__init__()
         # NOTE: non-chat model, all these special tokens keep randomly initialized.
@@ -271,9 +271,57 @@ class QwenTokenizer():
         return text
 
 
+class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
+    def __init__(self, token_path, skip_special_tokens=True):
+        # NOTE: non-chat model, all these special tokens keep randomly initialized.
+        special_tokens = {
+            'eos_token': '<|endoftext|>',
+            'pad_token': '<|endoftext|>',
+            'additional_special_tokens': [
+                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                '[breath]', '<strong>', '</strong>', '[noise]',
+                '[laughter]', '[cough]', '[clucking]', '[accent]',
+                '[quick_breath]',
+                "<laughter>", "</laughter>",
+                "[hissing]", "[sigh]", "[vocalized-noise]",
+                "[lipsmack]", "[mn]", "<|endofsystem|>",
+                "[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
+                "[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
+                "[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
+                "[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
+                "[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
+                "[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
+                "[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
+                "[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
+                "[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
+                "[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
+                "[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
+                "[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
+                "[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
+                "[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
+                "[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
+                "[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
+                "[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
+                "[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
+                "[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
+                "[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
+            ]
+        }
+        self.special_tokens = special_tokens
+        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+        self.tokenizer.add_special_tokens(special_tokens)
+        self.skip_special_tokens = skip_special_tokens
+
+
 @lru_cache(maxsize=None)
 def get_qwen_tokenizer(
     token_path: str,
-    skip_special_tokens: bool
-) -> QwenTokenizer:
-    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
+    skip_special_tokens: bool,
+    version: str = 'cosyvoice2'
+):
+    if version == 'cosyvoice2':
+        return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
+    elif version == 'cosyvoice3':
+        return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
+    else:
+        raise ValueError