Merge pull request #715 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
Xiang Lyu, 11 months ago
commit 94d6ce1006

+ 2 - 1
.gitignore

@@ -48,4 +48,5 @@ compile_commands.json
 *.pt
 pretrained_models/*
 *_pb2_grpc.py
-*_pb2.py
+*_pb2.py
+*.tar

+ 21 - 8
README.md

@@ -85,6 +85,7 @@ If you are expert in this field, and you are only interested in training your ow
 ``` python
 # SDK model download
 from modelscope import snapshot_download
+snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
 snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
 snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
 snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
@@ -95,6 +96,7 @@ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice
 ``` sh
 # git model download; please make sure git lfs is installed
 mkdir -p pretrained_models
+git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
 git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
 git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
 git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
@@ -109,11 +111,13 @@ Notice that this step is not necessary. If you do not install `ttsfrd` package,
 ``` sh
 cd pretrained_models/CosyVoice-ttsfrd/
 unzip resource.zip -d .
-pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
+pip install ttsfrd_dependency-0.1-py3-none-any.whl
+pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
 ```
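
As a quick sanity check that the optional package is importable, something like the following works (a minimal sketch, not part of the upstream README; the fallback matches the frontend, which uses WeTextProcessing when `ttsfrd` is absent):

``` python
# Minimal sketch: probe the optional ttsfrd package.
# CosyVoice's frontend falls back to WeTextProcessing-based
# normalization when this import fails.
try:
    import ttsfrd  # noqa: F401
    print('ttsfrd available, it will be used for text normalization')
except ImportError:
    print('ttsfrd not installed, falling back to WeTextProcessing')
```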
 
 **Basic Usage**
 
+We strongly recommend using `CosyVoice2-0.5B` for better performance.
 For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
 For sft inference, please use `CosyVoice-300M-SFT` model.
 For instruct inference, please use `CosyVoice-300M-Instruct` model.
@@ -124,36 +128,45 @@ export PYTHONPATH=third_party/Matcha-TTS
 ```
 
 ``` python
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav
 import torchaudio
 
+# cosyvoice2
+cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
+
+# zero_shot usage
+prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
+for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+# cosyvoice
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
 # sft usage
 print(cosyvoice.list_avaliable_spks())
 # change stream=True for chunk stream inference
 for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
-    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
+    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz') # or change to pretrained_models/CosyVoice-300M for 50Hz inference
 # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
 for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
-    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 # cross_lingual usage
 prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
 for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
-    torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)
+    torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 # vc usage
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
 source_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
 for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
-    torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], 22050)
+    torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
 # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
 for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
-    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
+    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 ```
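
For chunk streaming, pass `stream=True` and consume the generator incrementally; a brief sketch (the chunk collection here is illustrative, not from the upstream README):

``` python
import torch

# collect streaming chunks and write them out as one waveform;
# each yielded chunk['tts_speech'] has shape (1, num_samples)
chunks = []
for chunk in cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=True):
    chunks.append(chunk['tts_speech'])
torchaudio.save('zero_shot_stream.wav', torch.concat(chunks, dim=1), cosyvoice.sample_rate)
```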
 
 **Start web demo**
@@ -207,4 +220,4 @@ You can also scan the QR code to join our official Dingding chat group.
 5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
 
 ## Disclaimer
-The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
+The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.

+ 46 - 9
cosyvoice/cli/cosyvoice.py

@@ -18,7 +18,7 @@ from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
 
 
@@ -38,6 +38,7 @@ class CosyVoice:
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
                                           configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
             load_jit = False
             fp16 = False
@@ -64,7 +65,7 @@ class CosyVoice:
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
@@ -74,11 +75,11 @@ class CosyVoice:
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             if len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
-            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
@@ -87,11 +88,11 @@ class CosyVoice:
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
@@ -105,16 +106,52 @@ class CosyVoice:
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
 
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
-        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
+        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
         for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
-            speech_len = model_output['tts_speech'].shape[1] / 22050
+            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
             logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
             yield model_output
             start_time = time.time()
+
+
+class CosyVoice2(CosyVoice):
+
+    def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
+        instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'Qwen2-0.5B-CosyVoice-BlankEN')})
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          instruct,
+                                          configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and load_jit is True:
+            load_jit = False
+            logging.warning('cpu do not support jit, force set to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_trt is True and load_onnx is True:
+            load_onnx = False
+            logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
+        if load_onnx:
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
+        del configs
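
The recurring `22050` → `self.sample_rate` change above generalizes the real-time-factor (RTF) bookkeeping to models with other output rates. A minimal sketch of that computation (the helper name is illustrative):

``` python
import time

# RTF = wall-clock synthesis time / seconds of audio produced;
# values below 1.0 mean faster than real time. The audio duration
# is samples / sample_rate, now taken from the model config rather
# than a hard-coded 22050 Hz.
def compute_rtf(start_time: float, tts_speech, sample_rate: int) -> float:
    speech_len = tts_speech.shape[1] / sample_rate
    return (time.time() - start_time) / speech_len
```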

+ 26 - 25
cosyvoice/cli/frontend.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+import json
 import onnxruntime
 import torch
 import numpy as np
@@ -66,9 +67,7 @@ class CosyVoiceFrontEnd:
             ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
             assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
                 'failed to initialize ttsfrd resource'
-            self.frd.set_lang_type('pinyin')
-            self.frd.enable_pinyin_mix(True)
-            self.frd.set_breakmodel_index(1)
+            self.frd.set_lang_type('pinyinvg')
         else:
             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
             self.en_tn_model = EnNormalizer()
@@ -112,26 +111,28 @@ class CosyVoiceFrontEnd:
         text = text.strip()
         if contains_chinese(text):
             if self.use_ttsfrd:
-                text = self.frd.get_frd_extra_info(text, 'input')
+                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+                text = ''.join(texts)
             else:
                 text = self.zh_tn_model.normalize(text)
-            text = text.replace("\n", "")
-            text = replace_blank(text)
-            text = replace_corner_mark(text)
-            text = text.replace(".", "。")
-            text = text.replace(" - ", ",")
-            text = remove_bracket(text)
-            text = re.sub(r'[,,、]+$', '。', text)
-            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
-                                         token_min_n=60, merge_len=20, comma_split=False))
+                text = text.replace("\n", "")
+                text = replace_blank(text)
+                text = replace_corner_mark(text)
+                text = text.replace(".", "。")
+                text = text.replace(" - ", ",")
+                text = remove_bracket(text)
+                text = re.sub(r'[,,、]+$', '。', text)
+                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+                                             token_min_n=60, merge_len=20, comma_split=False))
         else:
             if self.use_ttsfrd:
-                text = self.frd.get_frd_extra_info(text, 'input')
+                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+                text = ''.join(texts)
             else:
                 text = self.en_tn_model.normalize(text)
-            text = spell_out_number(text, self.inflect_parser)
-            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
-                                         token_min_n=60, merge_len=20, comma_split=False))
+                text = spell_out_number(text, self.inflect_parser)
+                texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+                                             token_min_n=60, merge_len=20, comma_split=False))
         if split is False:
             return text
         return texts
@@ -142,11 +143,11 @@ class CosyVoiceFrontEnd:
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input
 
-    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
         prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
-        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
-        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
         speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
         embedding = self._extract_spk_embedding(prompt_speech_16k)
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
@@ -157,8 +158,8 @@ class CosyVoiceFrontEnd:
                        'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input
 
-    def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
-        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
+    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
         # in cross lingual mode, we remove prompt in llm
         del model_input['prompt_text']
         del model_input['prompt_text_len']
@@ -175,10 +176,10 @@ class CosyVoiceFrontEnd:
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input
 
-    def frontend_vc(self, source_speech_16k, prompt_speech_16k):
+    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
         prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
-        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
         embedding = self._extract_spk_embedding(prompt_speech_16k)
         source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
         model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
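
With `frontend_zero_shot`, `frontend_cross_lingual`, and `frontend_vc` now taking `resample_rate`, prompt resampling is no longer pinned to 22050 Hz but follows the model's configured rate (24000 Hz for CosyVoice2, per its config). A minimal sketch of the pattern:

``` python
import torchaudio

# Prompts stay at 16 kHz for speech tokenization and speaker
# embedding, but mel-feature extraction runs at the model's own
# sample rate, so the prompt is resampled first.
def resample_prompt(prompt_speech_16k, resample_rate: int):
    resampler = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)
    return resampler(prompt_speech_16k)
```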

+ 168 - 3
cosyvoice/cli/model.py

@@ -57,15 +57,15 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()
         if self.fp16 is True:
             self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
-        self.hift.load_state_dict(hift_state_dict, strict=False)
+        self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
 
     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
@@ -255,3 +255,168 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
+
+
+class CosyVoice2Model:
+
+    def __init__(self,
+                 llm: torch.nn.Module,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.llm = llm
+        self.flow = flow
+        self.hift = hift
+        self.token_hop_len = 2 * self.flow.input_frame_rate
+        # here we fix the flow encoder/decoder decoding_chunk_size; in the future we will pass it as an argument, or use a cache
+        self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
+        self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
+        # hift cache
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        # rtf and decoding related
+        self.stream_scale_factor = 1
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.hift_cache_dict = {}
+
+    def load(self, llm_model, flow_model, hift_model):
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+        self.llm.to(self.device).eval()
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+        self.flow.to(self.device).eval()
+        self.flow.decoder.fp16 = False
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+    def load_jit(self, flow_encoder_model):
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+
+    def load_trt(self, flow_decoder_estimator_model):
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+        self.flow.decoder.fp16 = True
+
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        with self.llm_context:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
+        tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_token=prompt_token.to(self.device),
+                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                         prompt_feat=prompt_feat.to(self.device),
+                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                         embedding=embedding.to(self.device),
+                                         finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.hift_cache_dict[this_uuid] = None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        p.start()
+        if stream is True:
+            token_offset = 0
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     token_offset=token_offset,
+                                                     finalize=False)
+                    token_offset += self.token_hop_len
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
+                    break
+            p.join()
+            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=token_offset,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            p.join()
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=0,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
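
A worked view of the streaming schedule in `CosyVoice2Model.tts` (a sketch; the numbers assume `input_frame_rate=25` and the `pre_lookahead_len=3` default from `CausalMaskedDiffWithXvec` below):

``` python
# A chunk is vocoded only once the LLM has produced enough tokens
# beyond token_offset to cover one hop plus the flow pre-lookahead;
# token_offset then advances by one hop per emitted chunk.
input_frame_rate, pre_lookahead_len = 25, 3
token_hop_len = 2 * input_frame_rate  # 50 tokens, roughly 2 s of audio

def ready_for_chunk(n_tokens: int, token_offset: int) -> bool:
    return n_tokens - token_offset >= token_hop_len + pre_lookahead_len

assert not ready_for_chunk(52, 0)  # 52 < 53, keep waiting
assert ready_for_chunk(53, 0)      # emit first chunk, token_offset += 50
```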

+ 91 - 11
cosyvoice/flow/decoder.py

@@ -13,16 +13,83 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import pack, rearrange, repeat
+from cosyvoice.utils.common import mask_to_bias
+from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
 from matcha.models.components.transformer import BasicTransformerBlock
 
 
+class Transpose(torch.nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: torch.Tensor):
+        x = torch.transpose(x, self.dim0, self.dim1)
+        return x
+
+
+class CausalBlock1D(Block1D):
+    def __init__(self, dim: int, dim_out: int):
+        super(CausalBlock1D, self).__init__(dim, dim_out)
+        self.block = torch.nn.Sequential(
+            CausalConv1d(dim, dim_out, 3),
+            Transpose(1, 2),
+            nn.LayerNorm(dim_out),
+            Transpose(1, 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
+        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+        self.block1 = CausalBlock1D(dim, dim_out)
+        self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride == 1
+        self.causal_padding = (kernel_size - 1, 0)
+
+    def forward(self, x: torch.Tensor):
+        x = F.pad(x, self.causal_padding)
+        x = super(CausalConv1d, self).forward(x)
+        return x
+
+
 class ConditionalDecoder(nn.Module):
     def __init__(
         self,
         in_channels,
         out_channels,
+        causal=False,
         channels=(256, 256),
         dropout=0.05,
         attention_head_dim=64,
@@ -39,7 +106,7 @@ class ConditionalDecoder(nn.Module):
         channels = tuple(channels)
         self.in_channels = in_channels
         self.out_channels = out_channels
-
+        self.causal = causal
         self.time_embeddings = SinusoidalPosEmb(in_channels)
         time_embed_dim = channels[0] * 4
         self.time_mlp = TimestepEmbedding(
@@ -56,7 +123,8 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
+                     else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -70,14 +138,16 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel) if not is_last else \
+                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
 
         for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                     ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
 
             transformer_blocks = nn.ModuleList(
                 [
@@ -99,7 +169,11 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[i] * 2
             output_channel = channels[i + 1]
             is_last = i == len(channels) - 2
-            resnet = ResnetBlock1D(
+            resnet = CausalResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            ) if self.causal else ResnetBlock1D(
                 dim=input_channel,
                 dim_out=output_channel,
                 time_emb_dim=time_embed_dim,
@@ -119,10 +193,10 @@ class ConditionalDecoder(nn.Module):
             upsample = (
                 Upsample1D(output_channel, use_conv_transpose=True)
                 if not is_last
-                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
-        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
@@ -175,7 +249,9 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -192,7 +268,9 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -207,7 +285,9 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -218,4 +298,4 @@ class ConditionalDecoder(nn.Module):
             x = upsample(x * mask_up)
         x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
-        return output * mask
+        return output * mask
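
The causal variants above pad only on the left, so each output frame can depend only on current and past inputs. A self-contained check of that property (a sketch using the same `(kernel_size - 1, 0)` padding as `CausalConv1d`):

``` python
import torch
import torch.nn.functional as F

# With left padding of (kernel_size - 1), output[t] depends only on
# input[t-2..t], so perturbing a future sample must leave earlier
# outputs unchanged.
conv = torch.nn.Conv1d(1, 1, kernel_size=3, padding=0)
x = torch.randn(1, 1, 10)
out = conv(F.pad(x, (2, 0)))
x_future = x.clone()
x_future[:, :, 7] += 1.0  # perturb a "future" frame
out_future = conv(F.pad(x_future, (2, 0)))
assert torch.allclose(out[:, :, :7], out_future[:, :, :7])
```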

+ 83 - 0
cosyvoice/flow/flow.py

@@ -146,3 +146,86 @@ class MaskedDiffWithXvec(torch.nn.Module):
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
         return feat, flow_cache
+
+
+class CausalMaskedDiffWithXvec(torch.nn.Module):
+    def __init__(self,
+                 input_size: int = 512,
+                 output_size: int = 80,
+                 spk_embed_dim: int = 192,
+                 output_type: str = "mel",
+                 vocab_size: int = 4096,
+                 input_frame_rate: int = 50,
+                 only_mask_loss: bool = True,
+                 token_mel_ratio: int = 2,
+                 pre_lookahead_len: int = 3,
+                 encoder: torch.nn.Module = None,
+                 decoder: torch.nn.Module = None,
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.decoder_conf = decoder_conf
+        self.mel_feat_conf = mel_feat_conf
+        self.vocab_size = vocab_size
+        self.output_type = output_type
+        self.input_frame_rate = input_frame_rate
+        logging.info(f"input frame rate={self.input_frame_rate}")
+        self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+        self.encoder = encoder
+        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+        self.decoder = decoder
+        self.only_mask_loss = only_mask_loss
+        self.token_mel_ratio = token_mel_ratio
+        self.pre_lookahead_len = pre_lookahead_len
+
+    @torch.inference_mode()
+    def inference(self,
+                  token,
+                  token_len,
+                  prompt_token,
+                  prompt_token_len,
+                  prompt_feat,
+                  prompt_feat_len,
+                  embedding,
+                  finalize):
+        assert token.shape[0] == 1
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        h, h_lengths = self.encoder(token, token_len)
+        if finalize is False:
+            h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
+        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
+        h = self.encoder_proj(h)
+
+        # get conditions
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+        feat, _ = self.decoder(
+            mu=h.transpose(1, 2).contiguous(),
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=10
+        )
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
+        return feat, None
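
The length bookkeeping in `CausalMaskedDiffWithXvec.inference` is easiest to follow with concrete numbers (a sketch using the class defaults `token_mel_ratio=2` and `pre_lookahead_len=3`; the prompt-frame count is approximate):

``` python
# Each speech token expands to token_mel_ratio mel frames; in
# streaming mode (finalize=False) the trailing look-ahead frames
# are withheld until the final call.
token_mel_ratio, pre_lookahead_len = 2, 3
prompt_tokens, new_tokens, finalize = 100, 50, False
h_len = (prompt_tokens + new_tokens) * token_mel_ratio  # 300 encoder frames
if not finalize:
    h_len -= pre_lookahead_len * token_mel_ratio        # withhold 6 frames
mel_len1 = prompt_tokens * token_mel_ratio              # 200 prompt frames
mel_len2 = h_len - mel_len1                             # 94 frames generated now
```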

+ 81 - 11
cosyvoice/flow/flow_matching.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import onnxruntime
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -88,30 +89,48 @@ class ConditionalCFM(BASECFM):
         # Or in future might add like a return_all_steps flag
         sol = []
 
+        if self.inference_cfg_rate > 0:
+            # Do not use concat here: it may change the memory format and make TensorRT infer wrong results!
+            x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+            mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+            spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        else:
+            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
         for step in range(1, len(t_span)):
-            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.forward_estimator(
-                    x, mask,
-                    torch.zeros_like(mu), t,
-                    torch.zeros_like(spks) if spks is not None else None,
-                    torch.zeros_like(cond)
-                )
-                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
-                           self.inference_cfg_rate * cfg_dphi_dt)
+                x_in[:] = x
+                mask_in[:] = mask
+                mu_in[0] = mu
+                t_in[:] = t.unsqueeze(0)
+                spks_in[0] = spks
+                cond_in[0] = cond
+            else:
+                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+            dphi_dt = self.forward_estimator(
+                x_in, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in
+            )
+            if self.inference_cfg_rate > 0:
+                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)
             if step < len(t_span) - 1:
                 dt = t_span[step + 1] - t
 
-        return sol[-1]
+        return sol[-1].float()
 
     def forward_estimator(self, x, mask, mu, t, spks, cond):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator.forward(x, mask, mu, t, spks, cond)
-        else:
+        elif isinstance(self.estimator, onnxruntime.InferenceSession):
             ort_inputs = {
                 'x': x.cpu().numpy(),
                 'mask': mask.cpu().numpy(),
@@ -122,6 +141,22 @@ class ConditionalCFM(BASECFM):
             }
             output = self.estimator.run(None, ort_inputs)[0]
             return torch.tensor(output, dtype=x.dtype, device=x.device)
+        else:
+            self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+            self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('t', (2,))
+            self.estimator.set_input_shape('spks', (2, 80))
+            self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+            # run trt engine
+            self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                        mask.contiguous().data_ptr(),
+                                        mu.contiguous().data_ptr(),
+                                        t.contiguous().data_ptr(),
+                                        spks.contiguous().data_ptr(),
+                                        cond.contiguous().data_ptr(),
+                                        x.data_ptr()])
+            return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss
@@ -163,3 +198,38 @@ class ConditionalCFM(BASECFM):
         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y
+
+
+class CausalConditionalCFM(ConditionalCFM):
+    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+        self.rand_noise = torch.randn([1, 80, 50 * 300])
+
+    @torch.inference_mode()
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
+        if self.fp16 is True:
+            z = z.half()
+        # fix prompt and overlap part mu and z
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+        if self.t_scheduler == 'cosine':
+            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
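
A self-contained sketch of the batched classifier-free-guidance Euler update that `solve_euler` implements. For brevity this uses `torch.cat`, whereas the real code writes into pre-allocated batch-of-2 buffers because, as its comment warns, concat can change the memory format and break TensorRT inference:

``` python
import torch

def euler_cfg_step(estimator, x, mask, mu, t, dt, spks, cond, cfg_rate):
    # conditional (row 0) and unconditional (row 1) passes share one batch
    dphi = estimator(torch.cat([x, x]),
                     torch.cat([mask, mask]),
                     torch.cat([mu, torch.zeros_like(mu)]),
                     t.repeat(2),  # t is a scalar tensor
                     torch.cat([spks, torch.zeros_like(spks)]),
                     torch.cat([cond, torch.zeros_like(cond)]))
    cond_dphi, uncond_dphi = dphi.chunk(2, dim=0)
    # VoiceBox-style guidance mix, then one explicit Euler step
    dphi = (1.0 + cfg_rate) * cond_dphi - cfg_rate * uncond_dphi
    return x + dt * dphi
```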

+ 125 - 0
cosyvoice/llm/llm.py

@@ -15,6 +15,7 @@ from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
+from transformers import Qwen2ForCausalLM
 from torch.nn.utils.rnn import pad_sequence, unpad_sequence
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
@@ -213,3 +214,127 @@ class TransformerLM(torch.nn.Module):
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+
+class Qwen2Encoder(torch.nn.Module):
+    def __init__(self, pretrain_path):
+        super().__init__()
+        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
+
+    def forward_one_step(self, xs, masks, cache=None):
+        input_masks = masks[:, -1, :]
+        outs = self.model(
+            inputs_embeds=xs,
+            attention_mask=input_masks,
+            output_hidden_states=True,
+            return_dict=True,
+            use_cache=True,
+            past_key_values=cache,
+        )
+        xs = outs.hidden_states[-1]
+        new_cache = outs.past_key_values
+        return xs, new_cache
+
+
+class Qwen2LM(torch.nn.Module):
+    def __init__(
+            self,
+            llm_input_size: int,
+            llm_output_size: int,
+            speech_token_size: int,
+            llm: torch.nn.Module,
+            sampling: Callable,
+            length_normalized_loss: bool = True,
+            lsm_weight: float = 0.0,
+    ):
+        super().__init__()
+        self.llm_input_size = llm_input_size
+        self.llm_output_size = llm_output_size
+        self.speech_token_size = speech_token_size
+
+        # 2. build speech token language model related modules
+        self.sos_eos = 0
+        self.task_id = 1
+        self.fill_token = 2
+
+        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
+        self.llm = llm
+        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
+        self.criterion_ce = LabelSmoothingLoss(
+            size=speech_token_size + 3,
+            padding_idx=IGNORE_ID,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
+
+        # 4. sampling method
+        self.sampling = sampling
+
+    def sampling_ids(
+            self,
+            weighted_scores: torch.Tensor,
+            decoded_tokens: List,
+            sampling: int,
+            ignore_eos: bool = True,
+    ):
+        while True:
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
+            if (not ignore_eos) or (self.speech_token_size not in top_ids):
+                break
+        return top_ids
+
+    @torch.inference_mode()
+    def inference(
+            self,
+            text: torch.Tensor,
+            text_len: torch.Tensor,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+        device = text.device
+        text = torch.concat([prompt_text, text], dim=1)
+        text_len += prompt_text_len
+        text = self.llm.model.model.embed_tokens(text)
+
+        # 2. encode embedding
+        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+
+        # 3. concat llm_input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+
+        # 4. cal min/max_length
+        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+        # 5. step by step decode
+        out_tokens = []
+        cache = None
+        for i in range(max_len):
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+            if top_ids == self.speech_token_size:
+                break
+            if top_ids > self.speech_token_size:
+                continue
+            # in stream mode, yield token one by one
+            yield top_ids
+            out_tokens.append(top_ids)
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
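
The decode loop above bounds the number of generated speech tokens by ratios of the synthesized text length and suppresses early EOS; a worked sketch of that control logic (values illustrative):

``` python
# While i < min_len, sampling_ids is called with ignore_eos=True and
# redraws whenever EOS (top_ids == speech_token_size) is sampled;
# decoding stops at EOS afterwards, or at max_len at the latest.
text_len, prompt_text_len = 40, 10
min_token_text_ratio, max_token_text_ratio = 2, 20
min_len = (text_len - prompt_text_len) * min_token_text_ratio  # 60 tokens
max_len = (text_len - prompt_text_len) * max_token_text_ratio  # 600 tokens
```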

+ 41 - 0
cosyvoice/tokenizer/tokenizer.py

@@ -2,6 +2,8 @@ import base64
 import os
 from functools import lru_cache
 from typing import Optional
+import torch
+from transformers import AutoTokenizer
 from whisper.tokenizer import Tokenizer
 
 import tiktoken
@@ -234,3 +236,42 @@ def get_tokenizer(
     return Tokenizer(
         encoding=encoding, num_languages=num_languages, language=language, task=task
     )
+
+
+class QwenTokenizer():
+    def __init__(self, token_path, skip_special_tokens=True):
+        super().__init__()
+        # NOTE: this is a non-chat model, so all these special tokens stay randomly initialized.
+        special_tokens = {
+            'eos_token': '<|endoftext|>',
+            'pad_token': '<|endoftext|>',
+            'additional_special_tokens': [
+                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                '[breath]', '<strong>', '</strong>', '[noise]',
+                '[laughter]', '[cough]', '[clucking]', '[accent]',
+                '[quick_breath]',
+                "<laughter>", "</laughter>",
+                "[hissing]", "[sigh]", "[vocalized-noise]",
+                "[lipsmack]", "[mn]"
+            ]
+        }
+        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+        self.tokenizer.add_special_tokens(special_tokens)
+        self.skip_special_tokens = skip_special_tokens
+
+    def encode(self, text, **kwargs):
+        tokens = self.tokenizer([text], return_tensors="pt")
+        tokens = tokens["input_ids"][0].cpu().tolist()
+        return tokens
+
+    def decode(self, tokens):
+        tokens = torch.tensor(tokens, dtype=torch.int64)
+        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+        return text
+
+@lru_cache(maxsize=None)
+def get_qwen_tokenizer(
+    token_path: str,
+    skip_special_tokens: bool
+) -> QwenTokenizer:
+    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)

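A quick round-trip with the new tokenizer (sketch; `token_path` is a placeholder for wherever your local Qwen tokenizer files live):

``` python
from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer

tokenizer = get_qwen_tokenizer(token_path='path/to/qwen_tokenizer', skip_special_tokens=True)
ids = tokenizer.encode('Nice weather today. [laughter]')
print(ids)                    # flat list[int]; each added special token maps to a single id
print(tokenizer.decode(ids))  # '[laughter]' is dropped because skip_special_tokens=True
```

Note that `get_qwen_tokenizer` is wrapped in `lru_cache`, so repeated calls with the same `(token_path, skip_special_tokens)` pair return the same tokenizer instance instead of reloading it.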
+ 7 - 7
cosyvoice/transformer/encoder_layer.py

@@ -49,8 +49,8 @@ class TransformerEncoderLayer(nn.Module):
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
-        self.norm1 = nn.LayerNorm(size, eps=1e-5)
-        self.norm2 = nn.LayerNorm(size, eps=1e-5)
+        self.norm1 = nn.LayerNorm(size, eps=1e-12)
+        self.norm2 = nn.LayerNorm(size, eps=1e-12)
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
@@ -142,17 +142,17 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = feed_forward
         self.feed_forward_macaron = feed_forward_macaron
         self.conv_module = conv_module
-        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
+        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
         if feed_forward_macaron is not None:
-            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
             self.ff_scale = 0.5
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
-            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
+            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
             self.norm_final = nn.LayerNorm(
-                size, eps=1e-5)  # for the final output of the block
+                size, eps=1e-12)  # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before

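The tighter epsilon mainly matters for activations with very small variance: with `eps=1e-5` the epsilon term dominates `sqrt(var + eps)` and flattens the normalized output, while `1e-12` leaves the statistics intact. A self-contained illustration (plain PyTorch, not repo code):

``` python
import torch
import torch.nn as nn

x = 1e-4 + torch.tensor([[0.0, 1e-6, 2e-6, 3e-6]])  # variance ~1.25e-12
for eps in (1e-5, 1e-12):
    ln = nn.LayerNorm(4, eps=eps, elementwise_affine=False)
    print(eps, ln(x))  # eps=1e-5 crushes the output toward 0; eps=1e-12 yields ~unit scale
```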
+ 321 - 0
cosyvoice/transformer/upsample_encoder.py

@@ -0,0 +1,321 @@
+# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
+#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
+#               2024 Alibaba Inc (Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from ESPnet(https://github.com/espnet/espnet)
+"""Encoder definition."""
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from cosyvoice.transformer.convolution import ConvolutionModule
+from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
+from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from cosyvoice.utils.class_utils import (
+    COSYVOICE_EMB_CLASSES,
+    COSYVOICE_SUBSAMPLE_CLASSES,
+    COSYVOICE_ATTENTION_CLASSES,
+    COSYVOICE_ACTIVATION_CLASSES,
+)
+from cosyvoice.utils.mask import make_pad_mask
+from cosyvoice.utils.mask import add_optional_chunk_mask
+
+
+class Upsample1D(nn.Module):
+    """A 1D upsampling layer with an optional convolution.
+
+    Parameters:
+        channels (`int`):
+            number of channels in the inputs and outputs.
+        use_conv (`bool`, default `False`):
+            option to use a convolution.
+        use_conv_transpose (`bool`, default `False`):
+            option to use a convolution transpose.
+        out_channels (`int`, optional):
+            number of output channels. Defaults to `channels`.
+    """
+
+    def __init__(self, channels: int, out_channels: int, stride: int = 2):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels
+        self.stride = stride
+        # In this mode, first nearest-neighbor interpolate, then conv with stride=1
+        self.conv = nn.Conv1d(
+            self.channels, self.out_channels, stride * 2 + 1, stride=1,
+            padding=0,
+        )
+
+    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
+        outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
+        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
+        outputs = self.conv(outputs)
+        return outputs, input_lengths * self.stride
+
+
+class PreLookaheadLayer(nn.Module):
+    def __init__(self, channels: int, pre_lookahead_len: int = 1):
+        super().__init__()
+        self.channels = channels
+        self.pre_lookahead_len = pre_lookahead_len
+        self.conv1 = nn.Conv1d(
+            channels, channels,
+            kernel_size=pre_lookahead_len + 1,
+            stride=1, padding=0,
+        )
+        self.conv2 = nn.Conv1d(
+            channels, channels,
+            kernel_size=3, stride=1, padding=0,
+        )
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        """
+        inputs: (batch_size, seq_len, channels)
+        """
+        outputs = inputs.transpose(1, 2).contiguous()
+        # look ahead: right-pad so conv1 sees pre_lookahead_len future frames
+        outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
+        outputs = F.leaky_relu(self.conv1(outputs))
+        # causal conv: left-pad by kernel_size - 1 so no future context leaks in
+        outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
+        outputs = self.conv2(outputs)
+        outputs = outputs.transpose(1, 2).contiguous()
+
+        # residual connection
+        outputs = outputs + inputs
+        return outputs
+
+
+class UpsampleConformerEncoder(torch.nn.Module):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: str = "conv2d",
+        pos_enc_layer_type: str = "rel_pos",
+        normalize_before: bool = True,
+        static_chunk_size: int = 0,
+        use_dynamic_chunk: bool = False,
+        global_cmvn: torch.nn.Module = None,
+        use_dynamic_left_chunk: bool = False,
+        positionwise_conv_kernel_size: int = 1,
+        macaron_style: bool = True,
+        selfattention_layer_type: str = "rel_selfattn",
+        activation_type: str = "swish",
+        use_cnn_module: bool = True,
+        cnn_module_kernel: int = 15,
+        causal: bool = False,
+        cnn_module_norm: str = "batch_norm",
+        key_bias: bool = True,
+        gradient_checkpointing: bool = False,
+    ):
+        """
+        Args:
+            input_size (int): input dim
+            output_size (int): dimension of attention
+            attention_heads (int): the number of heads in multi-head attention
+            linear_units (int): the number of hidden units in the
+                position-wise feed forward
+            num_blocks (int): the number of encoder blocks
+            dropout_rate (float): dropout rate
+            attention_dropout_rate (float): dropout rate in attention
+            positional_dropout_rate (float): dropout rate after adding
+                positional encoding
+            input_layer (str): input layer type.
+                optional [linear, conv2d, conv2d6, conv2d8]
+            pos_enc_layer_type (str): Encoder positional encoding layer type.
+                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
+            normalize_before (bool):
+                True: use layer_norm before each sub-block of a layer.
+                False: use layer_norm after each sub-block of a layer.
+            static_chunk_size (int): chunk size for static chunk training and
+                decoding
+            use_dynamic_chunk (bool): whether to use a dynamic chunk size for
+                training. You can only use a fixed chunk (chunk_size > 0)
+                or a dynamic chunk size (use_dynamic_chunk = True)
+            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
+            use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
+                dynamic chunk training
+            key_bias: whether to use bias in attention.linear_k, False for whisper models.
+            gradient_checkpointing: rerun the forward pass of each checkpointed
+                segment during backward, trading compute for memory.
+        """
+        super().__init__()
+        self._output_size = output_size
+
+        self.global_cmvn = global_cmvn
+        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
+            input_size,
+            output_size,
+            dropout_rate,
+            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
+                                                      positional_dropout_rate),
+        )
+
+        self.normalize_before = normalize_before
+        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+        self.static_chunk_size = static_chunk_size
+        self.use_dynamic_chunk = use_dynamic_chunk
+        self.use_dynamic_left_chunk = use_dynamic_left_chunk
+        self.gradient_checkpointing = gradient_checkpointing
+        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
+        # self-attention module definition
+        encoder_selfattn_layer_args = (
+            attention_heads,
+            output_size,
+            attention_dropout_rate,
+            key_bias,
+        )
+        # feed-forward module definition
+        positionwise_layer_args = (
+            output_size,
+            linear_units,
+            dropout_rate,
+            activation,
+        )
+        # convolution module definition
+        convolution_layer_args = (output_size, cnn_module_kernel, activation,
+                                  cnn_module_norm, causal)
+        self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
+        self.encoders = torch.nn.ModuleList([
+            ConformerEncoderLayer(
+                output_size,
+                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
+                    *encoder_selfattn_layer_args),
+                PositionwiseFeedForward(*positionwise_layer_args),
+                PositionwiseFeedForward(
+                    *positionwise_layer_args) if macaron_style else None,
+                ConvolutionModule(
+                    *convolution_layer_args) if use_cnn_module else None,
+                dropout_rate,
+                normalize_before,
+            ) for _ in range(num_blocks)
+        ])
+        self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
+        self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
+            input_size,
+            output_size,
+            dropout_rate,
+            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
+                                                      positional_dropout_rate),
+        )
+        self.up_encoders = torch.nn.ModuleList([
+            ConformerEncoderLayer(
+                output_size,
+                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
+                    *encoder_selfattn_layer_args),
+                PositionwiseFeedForward(*positionwise_layer_args),
+                PositionwiseFeedForward(
+                    *positionwise_layer_args) if macaron_style else None,
+                ConvolutionModule(
+                    *convolution_layer_args) if use_cnn_module else None,
+                dropout_rate,
+                normalize_before,
+            ) for _ in range(4)
+        ])
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs: torch.Tensor,
+        xs_lens: torch.Tensor,
+        decoding_chunk_size: int = 0,
+        num_decoding_left_chunks: int = -1,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Embed positions in tensor.
+
+        Args:
+            xs: padded input tensor (B, T, D)
+            xs_lens: input length (B)
+            decoding_chunk_size: decoding chunk size for dynamic chunk
+                0: default for training, use random dynamic chunk.
+                <0: for decoding, use full chunk.
+                >0: for decoding, use fixed chunk size as set.
+            num_decoding_left_chunks: number of left chunks, this is for decoding,
+                the chunk size is decoding_chunk_size.
+                >=0: use num_decoding_left_chunks
+                <0: use all left chunks
+        Returns:
+            encoder output tensor xs, and subsampled masks
+            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
+            masks: torch.Tensor batch padding mask after subsample
+                (B, 1, T' ~= T/subsample_rate)
+        NOTE(xcsong):
+            We pass the `__call__` method of the modules instead of `forward` to the
+            checkpointing API because `__call__` attaches all the hooks of the module.
+            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+        """
+        T = xs.size(1)
+        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+        xs, pos_emb, masks = self.embed(xs, masks)
+        mask_pad = masks  # (B, 1, T/subsample_rate)
+        chunk_masks = add_optional_chunk_mask(xs, masks,
+                                              self.use_dynamic_chunk,
+                                              self.use_dynamic_left_chunk,
+                                              decoding_chunk_size,
+                                              self.static_chunk_size,
+                                              num_decoding_left_chunks)
+        # lookahead + conformer encoder
+        xs = self.pre_lookahead_layer(xs)
+        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+        # upsample + conformer encoder
+        xs = xs.transpose(1, 2).contiguous()
+        xs, xs_lens = self.up_layer(xs, xs_lens)
+        xs = xs.transpose(1, 2).contiguous()
+        T = xs.size(1)
+        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
+        xs, pos_emb, masks = self.up_embed(xs, masks)
+        mask_pad = masks  # (B, 1, T/subsample_rate)
+        chunk_masks = add_optional_chunk_mask(xs, masks,
+                                              self.use_dynamic_chunk,
+                                              self.use_dynamic_left_chunk,
+                                              decoding_chunk_size,
+                                              self.static_chunk_size * self.up_layer.stride,
+                                              num_decoding_left_chunks)
+        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+        # Here we assume the mask is not changed in encoder layers, so just
+        # return the masks before encoder layers, and the masks will be used
+        # for cross attention with decoder later
+        return xs, masks
+
+    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+                       pos_emb: torch.Tensor,
+                       mask_pad: torch.Tensor) -> torch.Tensor:
+        for layer in self.encoders:
+            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+        return xs
+
+    def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+                          pos_emb: torch.Tensor,
+                          mask_pad: torch.Tensor) -> torch.Tensor:
+        for layer in self.up_encoders:
+            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+        return xs

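A shape sanity check for the new `Upsample1D` (sketch, assuming the repository root is on `PYTHONPATH`): nearest-neighbor interpolation doubles `T`, the left pad of `stride * 2` against the `stride * 2 + 1` kernel keeps the stride-1 conv causal, and the output length lands on exactly `T * stride`:

``` python
import torch
from cosyvoice.transformer.upsample_encoder import Upsample1D

up = Upsample1D(channels=512, out_channels=512, stride=2)
x = torch.randn(1, 512, 50)            # (B, C, T), Conv1d layout
y, y_lens = up(x, torch.tensor([50]))
print(y.shape, y_lens)                 # torch.Size([1, 512, 100]) tensor([100])
```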
+ 11 - 0
cosyvoice/utils/common.py

@@ -153,3 +153,14 @@ def set_all_random_seed(seed):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    assert mask.dtype == torch.bool
+    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+    mask = mask.to(dtype)
+    # attention mask bias
+    # NOTE(Mddct): torch.finfo jit issues
+    #     chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * torch.finfo(dtype).min
+    return mask

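Usage sketch for `mask_to_bias`: it converts a boolean keep-mask into the additive float form that fused attention kernels (for example the float `attn_mask` of `torch.nn.functional.scaled_dot_product_attention`) expect:

``` python
import torch
from cosyvoice.utils.common import mask_to_bias

mask = torch.tensor([[True, True, False]])  # True = attend, False = masked out
bias = mask_to_bias(mask, torch.float32)
print(bias)                                 # tensor([[0., 0., -3.4028e+38]]), added to attention logits
```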
+ 10 - 5
requirements.txt

@@ -1,4 +1,5 @@
---extra-index-url https://download.pytorch.org/whl/cu118
+--extra-index-url https://download.pytorch.org/whl/cu121
+--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
 conformer==0.3.2
 deepspeed==0.14.2; sys_platform == 'linux'
 diffusers==0.27.2
@@ -17,16 +18,20 @@ modelscope==1.15.0
 networkx==3.1
 omegaconf==2.3.0
 onnx==1.16.0
-onnxruntime-gpu==1.16.0; sys_platform == 'linux'
-onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows'
+onnxruntime-gpu==1.18.0; sys_platform == 'linux'
+onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32'
 openai-whisper==20231117
 protobuf==4.25
 pydantic==2.7.0
 rich==13.7.1
 soundfile==0.12.1
 tensorboard==2.14.0
-torch==2.0.1
-torchaudio==2.0.2
+tensorrt-cu12==10.0.1
+tensorrt-cu12-bindings==10.0.1
+tensorrt-cu12-libs==10.0.1
+torch==2.3.1
+torchaudio==2.3.1
+transformers==4.40.1
 uvicorn==0.30.0
 wget==3.2
 fastapi==0.111.0

+ 17 - 17
webui.py

@@ -22,7 +22,7 @@ import random
 import librosa
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav, logging
 from cosyvoice.utils.common import set_all_random_seed
 
@@ -51,7 +51,7 @@ def postprocess(speech, top_db=60, hop_length=220, win_length=440):
     )
     if speech.abs().max() > max_val:
         speech = speech / speech.abs().max() * max_val
-    speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
+    speech = torch.concat([speech, torch.zeros(1, int(cosyvoice.sample_rate * 0.2))], dim=1)
     return speech
 
 
@@ -71,31 +71,31 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group in ['自然语言控制']:
         if cosyvoice.frontend.instruct is False:
             gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
-            yield (target_sr, default_data)
+            yield (cosyvoice.sample_rate, default_data)
         if instruct_text == '':
             gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
-            yield (target_sr, default_data)
+            yield (cosyvoice.sample_rate, default_data)
         if prompt_wav is not None or prompt_text != '':
             gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
     # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
     if mode_checkbox_group in ['跨语种复刻']:
         if cosyvoice.frontend.instruct is True:
             gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
-            yield (target_sr, default_data)
+            yield (cosyvoice.sample_rate, default_data)
         if instruct_text != '':
             gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
         if prompt_wav is None:
             gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
-            yield (target_sr, default_data)
+            yield (cosyvoice.sample_rate, default_data)
         gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
     # if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
     if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
         if prompt_wav is None:
             gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
-            yield (target_sr, default_data)
+            yield (cosyvoice.sample_rate, default_data)
         if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
             gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
-            yield (target_sr, default_data)
+            yield (cosyvoice.sample_rate, default_data)
     # sft mode only use sft_dropdown
     if mode_checkbox_group in ['预训练音色']:
         if instruct_text != '' or prompt_wav is not None or prompt_text != '':
@@ -104,7 +104,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group in ['3s极速复刻']:
         if prompt_text == '':
             gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
-            yield (target_sr, default_data)
+            yield (cosyvoice.sample_rate, default_data)
         if instruct_text != '':
             gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
 
@@ -112,24 +112,24 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
         logging.info('get sft inference request')
         set_all_random_seed(seed)
         for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+            yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '3s极速复刻':
         logging.info('get zero_shot inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
         for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+            yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '跨语种复刻':
         logging.info('get cross_lingual inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
         for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+            yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
     else:
         logging.info('get instruct inference request')
         set_all_random_seed(seed)
         for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
-            yield (target_sr, i['tts_speech'].numpy().flatten())
+            yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
 
 
 def main():
@@ -178,11 +178,11 @@ if __name__ == '__main__':
                         default=8000)
     parser.add_argument('--model_dir',
                         type=str,
-                        default='pretrained_models/CosyVoice-300M',
+                        default='pretrained_models/CosyVoice2-0.5B',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
-    cosyvoice = CosyVoice(args.model_dir)
+    cosyvoice = CosyVoice2(args.model_dir) if 'CosyVoice2' in args.model_dir else CosyVoice(args.model_dir)
     sft_spk = cosyvoice.list_avaliable_spks()
-    prompt_sr, target_sr = 16000, 22050
-    default_data = np.zeros(target_sr)
+    prompt_sr = 16000
+    default_data = np.zeros(cosyvoice.sample_rate)
     main()
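Design note on the webui changes: Gradio's streaming audio output consumes `(sample_rate, numpy_array)` tuples, so deriving the rate from the loaded model instead of the old hard-coded `target_sr = 22050` keeps CosyVoice and CosyVoice2 (which report different `sample_rate` values) both correct. A hypothetical helper showing the shape of what each branch yields:

``` python
import numpy as np

def silent_chunk(sample_rate: int, seconds: float = 0.2):
    # the same (rate, waveform) tuple shape that every
    # `yield (cosyvoice.sample_rate, ...)` above produces
    return (sample_rate, np.zeros(int(sample_rate * seconds), dtype=np.float32))
```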