瀏覽代碼

Merge pull request #924 from FunAudioLLM/dev/lyuxiang.lx

add llm bistream
Xiang Lyu 8 月之前
父節點
當前提交
49761d2474

+ 10 - 0
README.md

@@ -143,6 +143,16 @@ for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒
 # instruct usage
 for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
     torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+# bistream usage, you can use generator as input, this is useful when using text llm model as input
+# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
+def text_generator():
+    yield '收到好友从远方寄来的生日礼物,'
+    yield '那份意外的惊喜与深深的祝福'
+    yield '让我心中充满了甜蜜的快乐,'
+    yield '笑容如花儿般绽放。'
+for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator, '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 ```
 
 **CosyVoice Usage**

+ 2 - 1
cosyvoice/cli/cosyvoice.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import time
+from typing import Generator
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
@@ -76,7 +77,7 @@ class CosyVoice:
     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            if len(i) < 0.5 * len(prompt_text):
+            if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
             start_time = time.time()

+ 20 - 4
cosyvoice/cli/frontend.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+from typing import Generator
 import json
 import onnxruntime
 import torch
@@ -31,6 +32,7 @@ except ImportError:
     from tn.chinese.normalizer import Normalizer as ZhNormalizer
     from tn.english.normalizer import Normalizer as EnNormalizer
     use_ttsfrd = False
+from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
 
 
@@ -71,10 +73,21 @@ class CosyVoiceFrontEnd:
             self.inflect_parser = inflect.engine()
 
     def _extract_text_token(self, text):
-        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
-        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
-        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
-        return text_token, text_token_len
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will return _extract_text_token_generator!')
+            # NOTE add a dummy text_token_len for compatibility
+            return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+        else:
+            text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+            text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+            return text_token, text_token_len
+
+    def _extract_text_token_generator(self, text_generator):
+        for text in text_generator:
+            text_token, _ = self._extract_text_token(text)
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
 
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
@@ -106,6 +119,9 @@ class CosyVoiceFrontEnd:
         return speech_feat, speech_feat_len
 
     def text_normalize(self, text, split=True, text_frontend=True):
+        if isinstance(text, Generator):
+            logging.info('get tts_text generator, will skip text_normalize!')
+            return [text]
         if text_frontend is False:
             return [text] if split is True else text
         text = text.strip()

+ 19 - 8
cosyvoice/cli/model.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Generator
 import torch
 import numpy as np
 import threading
@@ -99,14 +100,24 @@ class CosyVoiceModel:
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
-            for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device)):
-                self.tts_speech_token_dict[uuid].append(i)
+            if isinstance(text, Generator):
+                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+                for i in self.llm.inference_bistream(text=text,
+                                                     prompt_text=prompt_text.to(self.device),
+                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                     embedding=llm_embedding.to(self.device)):
+                    self.tts_speech_token_dict[uuid].append(i)
+            else:
+                for i in self.llm.inference(text=text.to(self.device),
+                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                            prompt_text=prompt_text.to(self.device),
+                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                            embedding=llm_embedding.to(self.device)):
+                    self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):

+ 107 - 23
cosyvoice/llm/llm.py

@@ -20,6 +20,7 @@ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
+from cosyvoice.utils.file_utils import logging
 
 
 class TransformerLM(torch.nn.Module):
@@ -144,10 +145,14 @@ class TransformerLM(torch.nn.Module):
             sampling: int,
             ignore_eos: bool = True,
     ):
+        num_trials, max_trials = 0, 100
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
         return top_ids
 
     @torch.inference_mode()
@@ -239,7 +244,7 @@ class Qwen2Encoder(torch.nn.Module):
         return xs, new_cache
 
 
-class Qwen2LM(torch.nn.Module):
+class Qwen2LM(TransformerLM):
     def __init__(
             self,
             llm_input_size: int,
@@ -249,8 +254,9 @@ class Qwen2LM(torch.nn.Module):
             sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
+            mix_ratio: List[int] = [5, 15],
     ):
-        super().__init__()
+        torch.nn.Module.__init__(self)
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
@@ -275,23 +281,7 @@ class Qwen2LM(torch.nn.Module):
 
         # 4. sampling method
         self.sampling = sampling
-
-    def sampling_ids(
-            self,
-            weighted_scores: torch.Tensor,
-            decoded_tokens: List,
-            sampling: int,
-            ignore_eos: bool = True,
-    ):
-        num_trials, max_trials = 0, 100
-        while True:
-            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (self.speech_token_size not in top_ids):
-                break
-            num_trials += 1
-            if num_trials > max_trials:
-                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
-        return top_ids
+        self.mix_ratio = mix_ratio
 
     @torch.inference_mode()
     def inference(
@@ -312,9 +302,6 @@ class Qwen2LM(torch.nn.Module):
         text_len += prompt_text_len
         text = self.llm.model.model.embed_tokens(text)
 
-        # 2. encode embedding
-        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
-
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
@@ -322,7 +309,7 @@ class Qwen2LM(torch.nn.Module):
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -345,3 +332,100 @@ class Qwen2LM(torch.nn.Module):
             yield top_ids
             out_tokens.append(top_ids)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+    @torch.inference_mode()
+    def inference_bistream(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+
+        device = prompt_text.device
+        # 1. prepare input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb], dim=1)
+
+        # 2. iterate text
+        out_tokens = []
+        cache = None
+        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
+        text_cache = self.llm.model.model.embed_tokens(prompt_text)
+        next_fill_index = -1
+        for this_text in text:
+            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+            # prompt_speech_token_emb not empty, try append to lm_input
+            while prompt_speech_token_emb.size(1) != 0:
+                if text_cache.size(1) >= self.mix_ratio[0]:
+                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+                else:
+                    logging.info('not enough text token to decode, wait for more')
+                    break
+            # no prompt_speech_token_emb remain, can decode some speech token
+            if prompt_speech_token_emb.size(1) == 0:
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    logging.info('get fill token, need to append more text token')
+                    if text_cache.size(1) >= self.mix_ratio[0]:
+                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
+                        logging.info('append {} text token'.format(lm_input_text.size(1)))
+                        lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+                        text_cache = text_cache[:, self.mix_ratio[0]:]
+                    else:
+                        logging.info('not enough text token to decode, wait for more')
+                        continue
+                while True:
+                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+                    y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                cache=cache)
+                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+                        top_ids = self.speech_token_size + 2
+                        next_fill_index += (self.mix_ratio[1] + 1)
+                    else:
+                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                    if top_ids == self.speech_token_size + 2:
+                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+                    out_tokens.append(top_ids)
+                    if top_ids >= self.speech_token_size:
+                        if top_ids == self.speech_token_size + 2:
+                            break
+                        else:
+                            raise ValueError('should not get token {}'.format(top_ids))
+                    yield top_ids
+                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+        # 3. final decode
+        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+        logging.info('no more text token, decode until met eos')
+        while True:
+            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
+            out_tokens.append(top_ids)
+            if top_ids >= self.speech_token_size:
+                if top_ids == self.speech_token_size:
+                    break
+                else:
+                    raise ValueError('should not get token {}'.format(top_ids))
+            # in stream mode, yield token one by one
+            yield top_ids
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

+ 1 - 1
cosyvoice/utils/common.py

@@ -162,5 +162,5 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     # attention mask bias
     # NOTE(Mddct): torch.finfo jit issues
     #     chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
-    mask = (1.0 - mask) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * -1.0e+10
     return mask

+ 1 - 0
examples/libritts/cosyvoice2/cosyvoice

@@ -0,0 +1 @@
+../../../cosyvoice

+ 1 - 0
examples/libritts/cosyvoice2/tools

@@ -0,0 +1 @@
+../../../tools

+ 1 - 1
tools/extract_speech_token.py

@@ -24,7 +24,7 @@ import whisper
 
 
 def single_job(utt):
-    audio, sample_rate = torchaudio.load(utt2wav[utt])
+    audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile')
     if sample_rate != 16000:
         audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
     if audio.shape[1] / 16000 > 30: