Browse Source

Merge pull request #327 from FunAudioLLM/inference_streaming

Inference streaming
Xiang Lyu 1 năm trước cách đây
mục cha
commit
20e0715dac

+ 2 - 0
.gitignore

@@ -43,6 +43,8 @@ compile_commands.json
 
 # train/inference files
 *.wav
+*.m4a
+*.aac
 *.pt
 pretrained_models/*
 *_pb2_grpc.py

+ 11 - 10
README.md

@@ -116,23 +116,24 @@ import torchaudio
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
 # sft usage
 print(cosyvoice.list_avaliable_spks())
-output = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女')
-torchaudio.save('sft.wav', output['tts_speech'], 22050)
+# change stream=True for chunk stream inference
+for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
+    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
 
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
 # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
-output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
-torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
 # cross_lingual usage
 prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
-output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
-torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
+    torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)
 
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
 # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
-output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
-torchaudio.save('instruct.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
+    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
 ```
 
 **Start web demo**
@@ -163,10 +164,10 @@ docker build -t cosyvoice:v1.0 .
 # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
 # for grpc usage
 docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
-python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 # for fastapi usage
 docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
-python3 fastapi/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 ```
 
 ## Discussion & Communication

+ 64 - 0
cosyvoice/bin/export_jit.py

@@ -0,0 +1,64 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+import torch
+from cosyvoice.cli.cosyvoice import CosyVoice
+
+def get_args():
+    parser = argparse.ArgumentParser(description='export your model for deployment')
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='pretrained_models/CosyVoice-300M',
+                        help='local path')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+
+    torch._C._jit_set_fusion_strategy([('STATIC', 1)])
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_set_profiling_executor(False)
+
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
+
+    # 1. export llm text_encoder
+    llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
+    script = torch.jit.script(llm_text_encoder)
+    script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+
+    # 2. export llm llm
+    llm_llm = cosyvoice.model.llm.llm.half()
+    script = torch.jit.script(llm_llm)
+    script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
+    script = torch.jit.optimize_for_inference(script)
+    script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+
+if __name__ == '__main__':
+    main()

+ 8 - 0
cosyvoice/bin/export_trt.py

@@ -0,0 +1,8 @@
+# TODO 跟export_jit一样的逻辑,完成flow部分的estimator的onnx导出。
+# tensorrt的安装方式,再这里写一下步骤提示如下,如果没有安装,那么不要执行这个脚本,提示用户先安装,不给选择
+try:
+    import tensorrt
+except ImportError:
+    print('step1, 下载\n step2. 解压,安装whl,')
+# 安装命令里tensosrt的根目录用环境变量导入,比如os.environ['tensorrt_root_dir']/bin/exetrace,然后python里subprocess里执行导出命令
+# 后面我会在run.sh里写好执行命令 tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx

+ 5 - 2
cosyvoice/bin/inference.py

@@ -100,10 +100,13 @@ def main():
                                'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                                'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                                'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
-            model_output = model.inference(**model_input)
+            tts_speeches = []
+            for model_output in model.inference(**model_input):
+                tts_speeches.append(model_output['tts_speech'])
+            tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-            torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+            torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
             f.write('{} {}\n'.format(tts_key, tts_fn))
             f.flush()
     f.close()

+ 38 - 22
cosyvoice/cli/cosyvoice.py

@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import torch
+import time
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir):
+    def __init__(self, model_dir, load_jit=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -38,46 +39,61 @@ class CosyVoice:
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+                                    '{}/llm.llm.fp16.zip'.format(model_dir))
         del configs
 
     def list_avaliable_spks(self):
         spks = list(self.frontend.spk2info.keys())
         return spks
 
-    def inference_sft(self, tts_text, spk_id):
-        tts_speeches = []
+    def inference_sft(self, tts_text, spk_id, stream=False):
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_sft(i, spk_id)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
 
-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
 
-    def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
 
-    def inference_instruct(self, tts_text, spk_id, instruct_text):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

+ 138 - 26
cosyvoice/cli/model.py

@@ -12,6 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import numpy as np
+import threading
+import time
+from contextlib import nullcontext
+import uuid
+from cosyvoice.utils.common import fade_in_out
+
 
 class CosyVoiceModel:
 
@@ -23,38 +30,143 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.token_min_hop_len = 100
+        self.token_max_hop_len = 200
+        self.token_overlap_len = 20
+        # mel fade in out
+        self.mel_overlap_len = 34
+        self.mel_window = np.hamming(2 * self.mel_overlap_len)
+        # hift cache
+        self.mel_cache_len = 20
+        self.source_cache_len = int(self.mel_cache_len * 256)
+        # rtf and decoding related
+        self.stream_scale_factor = 1
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.mel_overlap_dict = {}
+        self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
         self.llm.to(self.device).eval()
+        self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         self.flow.to(self.device).eval()
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()
 
-    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
-        tts_speech_token = self.llm.inference(text=text.to(self.device),
-                                              text_len=text_len.to(self.device),
-                                              prompt_text=prompt_text.to(self.device),
-                                              prompt_text_len=prompt_text_len.to(self.device),
-                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                              prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
-                                              embedding=llm_embedding.to(self.device),
-                                              beam_size=1,
-                                              sampling=25,
-                                              max_token_text_ratio=30,
-                                              min_token_text_ratio=3)
-        tts_mel = self.flow.inference(token=tts_speech_token,
-                                      token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                      prompt_token=flow_prompt_speech_token.to(self.device),
-                                      prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                      prompt_feat=prompt_speech_feat.to(self.device),
-                                      prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                      embedding=flow_embedding.to(self.device))
-        tts_speech = self.hift.inference(mel=tts_mel).cpu()
-        torch.cuda.empty_cache()
-        return {'tts_speech': tts_speech}
+    def load_jit(self, llm_text_encoder_model, llm_llm_model):
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model)
+        self.llm.text_encoder = llm_text_encoder
+        llm_llm = torch.jit.load(llm_llm_model)
+        self.llm.llm = llm_llm
+
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        with self.llm_context:
+            for i in self.llm.inference(text=text.to(self.device),
+                                                text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                                prompt_text=prompt_text.to(self.device),
+                                                prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                                prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                embedding=llm_embedding.to(self.device).half(),
+                                                sampling=25,
+                                                max_token_text_ratio=30,
+                                                min_token_text_ratio=3):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
+        with self.flow_hift_context:
+            tts_mel = self.flow.inference(token=token.to(self.device),
+                                        token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_token=prompt_token.to(self.device),
+                                        prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_feat=prompt_feat.to(self.device),
+                                        prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=embedding.to(self.device))
+            # mel overlap fade in out
+            if self.mel_overlap_dict[uuid] is not None:
+                tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+            # append hift cache
+            if self.hift_cache_dict[uuid] is not None:
+                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+            else:
+                hift_cache_source = torch.zeros(1, 1, 0)
+            # keep overlap mel and hift cache
+            if finalize is False:
+                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+                tts_speech = tts_speech[:, :-self.source_cache_len]
+            else:
+                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+        return tts_speech
+
+    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        p.start()
+        if stream is True:
+            token_hop_len = self.token_min_hop_len
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                    with self.flow_hift_context:
+                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                    prompt_token=flow_prompt_speech_token,
+                                                    prompt_feat=prompt_speech_feat,
+                                                    embedding=flow_embedding,
+                                                    uuid=this_uuid,
+                                                    finalize=False)
+                    yield  {'tts_speech': this_tts_speech.cpu()}
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                    break
+            p.join()
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                            prompt_token=flow_prompt_speech_token,
+                                            prompt_feat=prompt_speech_feat,
+                                            embedding=flow_embedding,
+                                            uuid=this_uuid,
+                                            finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            p.join()
+            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                            prompt_token=flow_prompt_speech_token,
+                                            prompt_feat=prompt_speech_feat,
+                                            embedding=flow_embedding,
+                                            uuid=this_uuid,
+                                            finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
+        torch.cuda.synchronize()

+ 9 - 9
cosyvoice/flow/flow.py

@@ -111,6 +111,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)
 
         # concat text and prompt_text
+        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -118,17 +119,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # text encode
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
-        feat_len = (token_len / 50 * 22050 / 256).int()
-        h, h_lengths = self.length_regulator(h, feat_len)
+        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
 
         # get conditions
-        conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
-        if prompt_feat.shape[1] != 0:
-            for i, j in enumerate(prompt_feat_len):
-                conds[i, :j] = prompt_feat[i]
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
-        mask = (~make_pad_mask(feat_len)).to(h)
+        # mask = (~make_pad_mask(feat_len)).to(h)
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
         feat = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
@@ -136,6 +136,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10
         )
-        if prompt_feat.shape[1] != 0:
-            feat = feat[:, :, prompt_feat.shape[1]:]
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
         return feat

+ 20 - 1
cosyvoice/flow/length_regulator.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Tuple
 import torch.nn as nn
+import torch
 from torch.nn import functional as F
 from cosyvoice.utils.mask import make_pad_mask
 
@@ -43,7 +44,25 @@ class InterpolateRegulator(nn.Module):
     def forward(self, x, ylens=None):
         # x in (B, T, D)
         mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
-        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
         out = self.model(x).transpose(1, 2).contiguous()
         olens = ylens
         return out * mask, olens
+
+    def inference(self, x1, x2, mel_len1, mel_len2):
+        # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
+        # x in (B, T, D)
+        if x2.shape[1] > 40:
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
+        else:
+            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+        if x1.shape[1] != 0:
+            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+            x = torch.concat([x1, x2], dim=2)
+        else:
+            x = x2
+        out = self.model(x).transpose(1, 2).contiguous()
+        return out, mel_len1 + mel_len2

+ 8 - 4
cosyvoice/hifigan/generator.py

@@ -335,10 +335,14 @@ class HiFTGenerator(nn.Module):
         inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         f0 = self.f0_predictor(x)
         s = self._f02source(f0)
 
+        # use cache_source to avoid glitch
+        if cache_source.shape[2] == 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
 
@@ -370,7 +374,7 @@ class HiFTGenerator(nn.Module):
 
         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x
+        return x, s
 
     def remove_weight_norm(self):
         print('Removing weight norm...')
@@ -387,5 +391,5 @@ class HiFTGenerator(nn.Module):
             l.remove_weight_norm()
 
     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor) -> torch.Tensor:
-        return self.forward(x=mel)
+    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        return self.forward(x=mel, cache_source=cache_source)

+ 14 - 13
cosyvoice/llm/llm.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
             speech_token_size: int,
             text_encoder: torch.nn.Module,
             llm: torch.nn.Module,
+            sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
             spk_embed_dim: int = 192,
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
         self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
 
+        # 4. sampling method
+        self.sampling = sampling
+
     def encode(
             self,
             text: torch.Tensor,
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
     def sampling_ids(
             self,
             weighted_scores: torch.Tensor,
-            sampling: Union[bool, int, float] = True,
-            beam_size: int = 1,
+            decoded_tokens: List,
+            sampling: int,
             ignore_eos: bool = True,
     ):
         while True:
-            prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
-            top_ids = prob.multinomial(beam_size, replacement=True)
-            top_ids = indices[top_ids]
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
         return top_ids
@@ -154,11 +156,10 @@ class TransformerLM(torch.nn.Module):
             prompt_speech_token: torch.Tensor,
             prompt_speech_token_len: torch.Tensor,
             embedding: torch.Tensor,
-            beam_size: int = 1,
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
-    ) -> torch.Tensor:
+    ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
@@ -173,7 +174,7 @@ class TransformerLM(torch.nn.Module):
             embedding = self.spk_embed_affine_layer(embedding)
             embedding = embedding.unsqueeze(dim=1)
         else:
-            embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
+            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
 
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -181,7 +182,7 @@ class TransformerLM(torch.nn.Module):
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
         lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length
@@ -196,11 +197,11 @@ class TransformerLM(torch.nn.Module):
             y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
+            # in stream mode, yield token one by one
+            yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-
-        return torch.tensor([out_tokens], dtype=torch.int64, device=device)

+ 7 - 3
cosyvoice/transformer/attention.py

@@ -222,7 +222,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         torch.nn.init.xavier_uniform_(self.pos_bias_u)
         torch.nn.init.xavier_uniform_(self.pos_bias_v)
 
-    def rel_shift(self, x):
+    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
         """Compute relative positional encoding.
 
         Args:
@@ -233,10 +233,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             torch.Tensor: Output tensor.
 
         """
-        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+                               device=x.device,
+                               dtype=x.dtype)
         x_padded = torch.cat([zero_pad, x], dim=-1)
 
-        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x_padded = x_padded.view(x.size()[0],
+                                 x.size()[1],
+                                 x.size(3) + 1, x.size(2))
         x = x_padded[:, :, 1:].view_as(x)[
             :, :, :, : x.size(-1) // 2 + 1
         ]  # only keep the positions from 0 to time2

+ 1 - 1
cosyvoice/transformer/decoder.py

@@ -174,7 +174,7 @@ class TransformerDecoder(torch.nn.Module):
                                                      memory_mask)
         return x
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, x: torch.Tensor,
                                     tgt_mask: torch.Tensor,
                                     memory: torch.Tensor,

+ 4 - 3
cosyvoice/transformer/embedding.py

@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model, dropout_rate, max_len=5000):
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
         """Construct an PositionalEncoding object."""
         super(EspnetRelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -221,7 +221,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 
-    def extend_pe(self, x):
+    def extend_pe(self, x: torch.Tensor):
         """Reset the positional encodings."""
         if self.pe is not None:
             # self.pe contains both positive and negative parts
@@ -253,7 +253,8 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
         pe = torch.cat([pe_positive, pe_negative], dim=1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
-    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
+    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
+            -> Tuple[torch.Tensor, torch.Tensor]:
         """Add positional encoding.
 
         Args:

+ 3 - 1
cosyvoice/transformer/encoder.py

@@ -169,7 +169,7 @@ class BaseEncoder(torch.nn.Module):
             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, xs: torch.Tensor,
                                     chunk_masks: torch.Tensor,
                                     pos_emb: torch.Tensor,
@@ -180,6 +180,7 @@ class BaseEncoder(torch.nn.Module):
                                                     mask_pad)
         return xs
 
+    @torch.jit.export
     def forward_chunk(
         self,
         xs: torch.Tensor,
@@ -270,6 +271,7 @@ class BaseEncoder(torch.nn.Module):
 
         return (xs, r_att_cache, r_cnn_cache)
 
+    @torch.jit.unused
     def forward_chunk_by_chunk(
         self,
         xs: torch.Tensor,

+ 36 - 0
cosyvoice/utils/common.py

@@ -101,3 +101,39 @@ def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
+
+# Repetition Aware Sampling in VALL-E 2
+def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
+    top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
+    rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
+    if rep_num >= win_size * tau_r:
+        top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
+    return top_ids
+
+def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
+    prob, indices = [], []
+    cum_prob = 0.0
+    sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
+    for i in range(len(sorted_idx)):
+        # sampling both top-p and numbers.
+        if cum_prob < top_p and len(prob) < top_k:
+            cum_prob += sorted_value[i]
+            prob.append(sorted_value[i])
+            indices.append(sorted_idx[i])
+        else:
+            break
+    prob = torch.tensor(prob).to(weighted_scores)
+    indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
+    top_ids = indices[prob.multinomial(1, replacement=True)]
+    return top_ids
+
+def random_sampling(weighted_scores, decoded_tokens, sampling):
+    top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
+    return top_ids
+
+def fade_in_out(fade_in_mel, fade_out_mel, window):
+    device = fade_in_mel.device
+    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+    mel_overlap_len = int(window.shape[0] / 2)
+    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+    return fade_in_mel.to(device)

+ 4 - 0
cosyvoice/utils/file_utils.py

@@ -15,6 +15,10 @@
 
 import json
 import torchaudio
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s %(levelname)s %(message)s')
 
 
 def read_lists(list_file):

+ 8 - 3
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

@@ -31,7 +31,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
         num_blocks: 3
         dropout_rate: 0.1
         positional_dropout_rate: 0.1
-        attention_dropout_rate: 0
+        attention_dropout_rate: 0.0
         normalize_before: True
         input_layer: 'linear'
         pos_enc_layer_type: 'rel_pos_espnet'
@@ -49,11 +49,16 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
         num_blocks: 7
         dropout_rate: 0.1
         positional_dropout_rate: 0.1
-        attention_dropout_rate: 0
+        attention_dropout_rate: 0.0
         input_layer: 'linear_legacy'
         pos_enc_layer_type: 'rel_pos_espnet'
         selfattention_layer_type: 'rel_selfattn'
         static_chunk_size: 1
+    sampling: !name:cosyvoice.utils.common.ras_sampling
+        top_p: 0.8
+        top_k: 25
+        win_size: 10
+        tau_r: 0.1
 
 flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
     input_size: 512
@@ -97,7 +102,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
             in_channels: 320
             out_channels: 80
             channels: [256, 256]
-            dropout: 0
+            dropout: 0.0
             attention_head_dim: 64
             n_blocks: 4
             num_mid_blocks: 8

+ 8 - 3
examples/libritts/cosyvoice/conf/cosyvoice.yaml

@@ -31,7 +31,7 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
         num_blocks: 6
         dropout_rate: 0.1
         positional_dropout_rate: 0.1
-        attention_dropout_rate: 0
+        attention_dropout_rate: 0.0
         normalize_before: True
         input_layer: 'linear'
         pos_enc_layer_type: 'rel_pos_espnet'
@@ -49,11 +49,16 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
         num_blocks: 14
         dropout_rate: 0.1
         positional_dropout_rate: 0.1
-        attention_dropout_rate: 0
+        attention_dropout_rate: 0.0
         input_layer: 'linear_legacy'
         pos_enc_layer_type: 'rel_pos_espnet'
         selfattention_layer_type: 'rel_selfattn'
         static_chunk_size: 1
+    sampling: !name:cosyvoice.utils.common.ras_sampling
+        top_p: 0.8
+        top_k: 25
+        win_size: 10
+        tau_r: 0.1
 
 flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
     input_size: 512
@@ -97,7 +102,7 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
             in_channels: 320
             out_channels: 80
             channels: [256, 256]
-            dropout: 0
+            dropout: 0.0
             attention_head_dim: 64
             n_blocks: 4
             num_mid_blocks: 12

+ 4 - 1
runtime/python/grpc/client.py

@@ -61,8 +61,11 @@ def main():
             request.instruct_request.CopyFrom(instruct_request)
 
         response = stub.Inference(request)
+        tts_audio = b''
+        for r in response:
+            tts_audio += r.tts_audio
+        tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
         logging.info('save response to {}'.format(args.tts_wav))
-        tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
         torchaudio.save(args.tts_wav, tts_speech, target_sr)
         logging.info('get response')
 

+ 1 - 1
runtime/python/grpc/cosyvoice.proto

@@ -4,7 +4,7 @@ package cosyvoice;
 option go_package = "protos/";
 
 service CosyVoice{
-  rpc Inference(Request) returns (Response) {}
+  rpc Inference(Request) returns (stream Response) {}
 }
 
 message Request{

+ 4 - 3
runtime/python/grpc/server.py

@@ -54,9 +54,10 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
             model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
 
         logging.info('send inference response')
-        response = cosyvoice_pb2.Response()
-        response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
-        return response
+        for i in model_output:
+            response = cosyvoice_pb2.Response()
+            response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
+            yield response
 
 def main():
     grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)

+ 15 - 26
webui.py

@@ -24,14 +24,8 @@ import torchaudio
 import random
 import librosa
 
-import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
-
 from cosyvoice.cli.cosyvoice import CosyVoice
-from cosyvoice.utils.file_utils import load_wav, speed_change
-
-logging.basicConfig(level=logging.DEBUG,
-                    format='%(asctime)s %(levelname)s %(message)s')
+from cosyvoice.utils.file_utils import load_wav, speed_change, logging
 
 def generate_seed():
     seed = random.randint(1, 100000000)
@@ -63,10 +57,11 @@ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成
                  '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
                  '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
                  '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
+stream_mode_list = [('否', False), ('是', True)]
 def change_instruction(mode_checkbox_group):
     return instruct_dict[mode_checkbox_group]
 
-def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor):
+def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor):
     if prompt_wav_upload is not None:
         prompt_wav = prompt_wav_upload
     elif prompt_wav_record is not None:
@@ -117,32 +112,25 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group == '预训练音色':
         logging.info('get sft inference request')
         set_all_random_seed(seed)
-        output = cosyvoice.inference_sft(tts_text, sft_dropdown)
+        for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
+            yield (target_sr,  i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '3s极速复刻':
         logging.info('get zero_shot inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
+        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
+            yield (target_sr,  i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '跨语种复刻':
         logging.info('get cross_lingual inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
+        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
+            yield (target_sr,  i['tts_speech'].numpy().flatten())
     else:
         logging.info('get instruct inference request')
         set_all_random_seed(seed)
-        output = cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text)
-
-    if speed_factor != 1.0:
-        try:
-            audio_data, sample_rate = speed_change(output["tts_speech"], target_sr, str(speed_factor))
-            audio_data = audio_data.numpy().flatten()
-        except Exception as e:
-            print(f"Failed to change speed of audio: \n{e}")
-    else:
-        audio_data = output['tts_speech'].numpy().flatten()
-
-    return (target_sr, audio_data)
+        for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
+            yield (target_sr,  i['tts_speech'].numpy().flatten())
 
 def main():
     with gr.Blocks() as demo:
@@ -155,6 +143,7 @@ def main():
             mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
             instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
             sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
+            stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
             with gr.Column(scale=0.25):
                 seed_button = gr.Button(value="\U0001F3B2")
                 seed = gr.Number(value=0, label="随机推理种子")
@@ -167,11 +156,11 @@ def main():
 
         generate_button = gr.Button("生成音频")
 
-        audio_output = gr.Audio(label="合成音频")
+        audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
 
         seed_button.click(generate_seed, inputs=[], outputs=seed)
         generate_button.click(generate_audio,
-                              inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor],
+                              inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor],
                               outputs=[audio_output])
         mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
     demo.queue(max_size=4, default_concurrency_limit=2)
@@ -184,7 +173,7 @@ if __name__ == '__main__':
                         default=8000)
     parser.add_argument('--model_dir',
                         type=str,
-                        default='iic/CosyVoice-300M',
+                        default='pretrained_models/CosyVoice-300M',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
     cosyvoice = CosyVoice(args.model_dir)