Sfoglia il codice sorgente

Merge pull request #354 from FunAudioLLM/dev/lyuxiang.lx

fix readme
Xiang Lyu 1 anno fa
parent
commit
33a585374a

+ 1 - 1
.github/workflows/lint.yml

@@ -51,5 +51,5 @@ jobs:
           set -eux
           pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
           flake8 --version
-          flake8 --max-line-length 120 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
+          flake8 --max-line-length 150 --ignore B006,B008,B905,C408,E402,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
           if [ $? != 0 ]; then exit 1; fi

+ 1 - 1
README.md

@@ -12,7 +12,7 @@ For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVo
     - [x] WeTextProcessing support when ttsfrd is not avaliable
     - [x] Fastapi server and client
 
-- [ ] 2024/08
+- [x] 2024/08
 
     - [x] Repetition Aware Sampling(RAS) inference for llm stability
     - [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization

+ 4 - 1
cosyvoice/bin/export_jit.py

@@ -19,12 +19,13 @@ import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
 import sys
+import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-import torch
 from cosyvoice.cli.cosyvoice import CosyVoice
 
+
 def get_args():
     parser = argparse.ArgumentParser(description='export your model for deployment')
     parser.add_argument('--model_dir',
@@ -35,6 +36,7 @@ def get_args():
     print(args)
     return args
 
+
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -67,5 +69,6 @@ def main():
     script = torch.jit.optimize_for_inference(script)
     script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
 
+
 if __name__ == '__main__':
     main()

+ 7 - 4
cosyvoice/bin/export_onnx.py

@@ -20,13 +20,13 @@ import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
 import sys
-ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/../..'.format(ROOT_DIR))
-sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 import onnxruntime
 import random
 import torch
 from tqdm import tqdm
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice
 
 
@@ -50,6 +50,7 @@ def get_args():
     print(args)
     return args
 
+
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -89,7 +90,8 @@ def main():
     option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
     option.intra_op_num_threads = 1
     providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), sess_options=option, providers=providers)
+    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                  sess_options=option, providers=providers)
 
     for _ in tqdm(range(10)):
         x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
@@ -105,5 +107,6 @@ def main():
         output_onnx = estimator_onnx.run(None, ort_inputs)[0]
         torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
 
+
 if __name__ == "__main__":
     main()

+ 4 - 6
cosyvoice/bin/inference.py

@@ -18,16 +18,15 @@ import argparse
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
-
 import torch
 from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
 from tqdm import tqdm
 from cosyvoice.cli.model import CosyVoiceModel
-
 from cosyvoice.dataset.dataset import Dataset
 
+
 def get_args():
     parser = argparse.ArgumentParser(description='inference with your model')
     parser.add_argument('--config', required=True, help='config file')
@@ -66,7 +65,8 @@ def main():
     model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
     model.load(args.llm_model, args.flow_model, args.hifigan_model)
 
-    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
+                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
 
     del configs
@@ -74,13 +74,11 @@ def main():
     fn = os.path.join(args.result_dir, 'wav.scp')
     f = open(fn, 'w')
     with torch.no_grad():
-        for batch_idx, batch in tqdm(enumerate(test_data_loader)):
+        for _, batch in tqdm(enumerate(test_data_loader)):
             utts = batch["utts"]
             assert len(utts) == 1, "inference mode only support batchsize 1"
-            text = batch["text"]
             text_token = batch["text_token"].to(device)
             text_token_len = batch["text_token_len"].to(device)
-            tts_text = batch["tts_text"]
             tts_index = batch["tts_index"]
             tts_text_token = batch["tts_text_token"].to(device)
             tts_text_token_len = batch["tts_text_token_len"].to(device)

+ 1 - 0
cosyvoice/bin/train.py

@@ -132,5 +132,6 @@ def main():
         executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
         dist.destroy_process_group(group_join)
 
+
 if __name__ == '__main__':
     main()

+ 3 - 2
cosyvoice/cli/cosyvoice.py

@@ -20,6 +20,7 @@ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
 from cosyvoice.utils.file_utils import logging
 
+
 class CosyVoice:
 
     def __init__(self, model_dir, load_jit=True, load_onnx=True):
@@ -42,8 +43,8 @@ class CosyVoice:
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-                                    '{}/llm.llm.fp16.zip'.format(model_dir),
-                                    '{}/flow.encoder.fp32.zip'.format(model_dir))
+                                '{}/llm.llm.fp16.zip'.format(model_dir),
+                                '{}/flow.encoder.fp32.zip'.format(model_dir))
         if load_onnx:
             self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
         del configs

+ 16 - 11
cosyvoice/cli/frontend.py

@@ -50,7 +50,9 @@ class CosyVoiceFrontEnd:
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
-        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if torch.cuda.is_available() else "CPUExecutionProvider"])
+        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+                                                                                "CPUExecutionProvider"])
         if os.path.exists(spk2info):
             self.spk2info = torch.load(spk2info, map_location=self.device)
         self.instruct = instruct
@@ -60,7 +62,8 @@ class CosyVoiceFrontEnd:
         if self.use_ttsfrd:
             self.frd = ttsfrd.TtsFrontendEngine()
             ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
+            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+                'failed to initialize ttsfrd resource'
             self.frd.set_lang_type('pinyin')
             self.frd.enable_pinyin_mix(True)
             self.frd.set_breakmodel_index(1)
@@ -76,8 +79,11 @@ class CosyVoiceFrontEnd:
 
     def _extract_speech_token(self, speech):
         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
-        speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
-                                                                self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+        speech_token = self.speech_tokenizer_session.run(None,
+                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                          feat.detach().cpu().numpy(),
+                                                          self.speech_tokenizer_session.get_inputs()[1].name:
+                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
         speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
         return speech_token, speech_token_len
@@ -88,7 +94,8 @@ class CosyVoiceFrontEnd:
                            dither=0,
                            sample_frequency=16000)
         feat = feat - feat.mean(dim=0, keepdim=True)
-        embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+        embedding = self.campplus_session.run(None,
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
         embedding = torch.tensor([embedding]).to(self.device)
         return embedding
 
@@ -112,18 +119,16 @@ class CosyVoiceFrontEnd:
             text = text.replace(" - ", ",")
             text = remove_bracket(text)
             text = re.sub(r'[,,]+$', '。', text)
-            texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
-                                                token_min_n=60, merge_len=20,
-                                                comma_split=False)]
+            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+                                         token_min_n=60, merge_len=20, comma_split=False))
         else:
             if self.use_ttsfrd:
                 text = self.frd.get_frd_extra_info(text, 'input')
             else:
                 text = self.en_tn_model.normalize(text)
             text = spell_out_number(text, self.inflect_parser)
-            texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
-                                                token_min_n=60, merge_len=20,
-                                                comma_split=False)]
+            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+                                         token_min_n=60, merge_len=20, comma_split=False))
         if split is False:
             return text
         return texts

+ 34 - 33
cosyvoice/cli/model.py

@@ -18,7 +18,7 @@ import time
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-import numpy as np
+
 
 class CosyVoiceModel:
 
@@ -80,27 +80,27 @@ class CosyVoiceModel:
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
-                                                text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                                prompt_text=prompt_text.to(self.device),
-                                                prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                                prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                                prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                embedding=llm_embedding.to(self.device).half(),
-                                                sampling=25,
-                                                max_token_text_ratio=30,
-                                                min_token_text_ratio=3):
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device).half(),
+                                        sampling=25,
+                                        max_token_text_ratio=30,
+                                        min_token_text_ratio=3):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
         with self.flow_hift_context:
             tts_mel = self.flow.inference(token=token.to(self.device),
-                                        token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_token=prompt_token.to(self.device),
-                                        prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_feat=prompt_feat.to(self.device),
-                                        prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=embedding.to(self.device))
+                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_token=prompt_token.to(self.device),
+                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_feat=prompt_feat.to(self.device),
+                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                          embedding=embedding.to(self.device))
             # mel overlap fade in out
             if self.mel_overlap_dict[uuid] is not None:
                 tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
@@ -129,7 +129,8 @@ class CosyVoiceModel:
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
-            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
@@ -140,12 +141,12 @@ class CosyVoiceModel:
                     this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                     with self.flow_hift_context:
                         this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                    prompt_token=flow_prompt_speech_token,
-                                                    prompt_feat=prompt_speech_feat,
-                                                    embedding=flow_embedding,
-                                                    uuid=this_uuid,
-                                                    finalize=False)
-                    yield  {'tts_speech': this_tts_speech.cpu()}
+                                                         prompt_token=flow_prompt_speech_token,
+                                                         prompt_feat=prompt_speech_feat,
+                                                         embedding=flow_embedding,
+                                                         uuid=this_uuid,
+                                                         finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
                     with self.lock:
                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                     # increase token_hop_len for better speech quality
@@ -157,11 +158,11 @@ class CosyVoiceModel:
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                            prompt_token=flow_prompt_speech_token,
-                                            prompt_feat=prompt_speech_feat,
-                                            embedding=flow_embedding,
-                                            uuid=this_uuid,
-                                            finalize=True)
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
@@ -169,11 +170,11 @@ class CosyVoiceModel:
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                            prompt_token=flow_prompt_speech_token,
-                                            prompt_feat=prompt_speech_feat,
-                                            embedding=flow_embedding,
-                                            uuid=this_uuid,
-                                            finalize=True)
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)

+ 1 - 1
cosyvoice/dataset/dataset.py

@@ -148,7 +148,7 @@ def Dataset(data_list_file,
             tts_data = json.load(f)
         utt2lists = read_json_lists(prompt_utt2data)
         # filter unnecessary file in inference mode
-        lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
+        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
     dataset = DataList(lists,
                        shuffle=shuffle,
                        partition=partition)

+ 2 - 1
cosyvoice/dataset/processor.py

@@ -23,7 +23,7 @@ import torch.nn.functional as F
 
 torchaudio.set_audio_backend('soundfile')
 
-AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
+AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
 
 
 def parquet_opener(data, mode='train', tts_data={}):
@@ -54,6 +54,7 @@ def parquet_opener(data, mode='train', tts_data={}):
         except Exception as ex:
             logging.warning('Failed to open {}, ex info {}'.format(url, ex))
 
+
 def filter(data,
            max_length=10240,
            min_length=10,

+ 1 - 2
cosyvoice/flow/decoder.py

@@ -74,7 +74,7 @@ class ConditionalDecoder(nn.Module):
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
 
-        for i in range(num_mid_blocks):
+        for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
             resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
@@ -126,7 +126,6 @@ class ConditionalDecoder(nn.Module):
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
-
     def initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv1d):

+ 7 - 2
cosyvoice/flow/flow.py

@@ -33,8 +33,13 @@ class MaskedDiffWithXvec(torch.nn.Module):
                  encoder: torch.nn.Module = None,
                  length_regulator: torch.nn.Module = None,
                  decoder: torch.nn.Module = None,
-                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
-                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size

+ 1 - 0
cosyvoice/flow/flow_matching.py

@@ -15,6 +15,7 @@ import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
 
+
 class ConditionalCFM(BASECFM):
     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
         super().__init__(

+ 7 - 4
cosyvoice/flow/length_regulator.py

@@ -38,6 +38,8 @@ This code is modified from https://github.com/jik876/hifi-gan
  https://github.com/NVIDIA/BigVGAN
 
 """
+
+
 class ResBlock(torch.nn.Module):
     """Residual block module in HiFiGAN/BigVGAN."""
     def __init__(
@@ -100,6 +102,7 @@ class ResBlock(torch.nn.Module):
             remove_weight_norm(self.convs1[idx])
             remove_weight_norm(self.convs2[idx])
 
+
 class SineGen(torch.nn.Module):
     """ Definition of sine generator
     SineGen(samp_rate, harmonic_num = 0,
@@ -286,8 +289,7 @@ class HiFTGenerator(nn.Module):
         self.source_resblocks = nn.ModuleList()
         downsample_rates = [1] + upsample_rates[::-1][:-1]
         downsample_cum_rates = np.cumprod(downsample_rates)
-        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
-                                          source_resblock_dilation_sizes)):
+        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
             if u == 1:
                 self.source_downs.append(
                     Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
@@ -304,7 +306,7 @@ class HiFTGenerator(nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = base_channels // (2**(i + 1))
-            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(ResBlock(ch, k, d))
 
         self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
@@ -332,7 +334,8 @@ class HiFTGenerator(nn.Module):
         magnitude = torch.clip(magnitude, max=1e2)
         real = magnitude * torch.cos(phase)
         img = magnitude * torch.sin(phase)
-        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform
 
     def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:

+ 10 - 5
cosyvoice/llm/llm.py

@@ -80,7 +80,8 @@ class TransformerLM(torch.nn.Module):
     def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
+        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+                    for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
         return lm_input, lm_input_len
@@ -104,7 +105,8 @@ class TransformerLM(torch.nn.Module):
         embedding = batch['embedding'].to(device)
 
         # 1. prepare llm_target
-        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
+        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
+                                  [self.speech_token_size]) for i in range(text_token.size(0))]
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
 
         # 1. encode text_token
@@ -124,7 +126,8 @@ class TransformerLM(torch.nn.Module):
         speech_token = self.speech_embedding(speech_token)
 
         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+                                                         task_id_emb, speech_token, speech_token_len)
 
         # 6. run lm forward
         lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
@@ -194,8 +197,10 @@ class TransformerLM(torch.nn.Module):
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
-                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
+            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1,
+                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
+                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
+                                                                                                 device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:

+ 2 - 2
cosyvoice/transformer/embedding.py

@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
         """Construct an PositionalEncoding object."""
         super(EspnetRelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -289,6 +289,6 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
         """
         pos_emb = self.pe[
             :,
-            self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
+            self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
         ]
         return pos_emb

+ 6 - 1
cosyvoice/utils/common.py

@@ -102,6 +102,7 @@ def init_weights(m, mean=0.0, std=0.01):
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
 
+
 # Repetition Aware Sampling in VALL-E 2
 def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
     top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
@@ -110,6 +111,7 @@ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25,
         top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
     return top_ids
 
+
 def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
     prob, indices = [], []
     cum_prob = 0.0
@@ -127,13 +129,16 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
     top_ids = indices[prob.multinomial(1, replacement=True)]
     return top_ids
 
+
 def random_sampling(weighted_scores, decoded_tokens, sampling):
     top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
     return top_ids
 
+
 def fade_in_out(fade_in_mel, fade_out_mel, window):
     device = fade_in_mel.device
     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
     mel_overlap_len = int(window.shape[0] / 2)
-    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
+        fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)

+ 2 - 1
cosyvoice/utils/executor.py

@@ -70,7 +70,8 @@ class Executor:
                 info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
                 log_per_step(writer, info_dict)
                 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
-                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                   (batch_idx + 1) % info_dict["accum_grad"] == 0:
                     dist.barrier()
                     self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
                     model.train()

+ 3 - 0
cosyvoice/utils/file_utils.py

@@ -28,6 +28,7 @@ def read_lists(list_file):
             lists.append(line.strip())
     return lists
 
+
 def read_json_lists(list_file):
     lists = read_lists(list_file)
     results = {}
@@ -36,6 +37,7 @@ def read_json_lists(list_file):
             results.update(json.load(fin))
     return results
 
+
 def load_wav(wav, target_sr):
     speech, sample_rate = torchaudio.load(wav)
     speech = speech.mean(dim=0, keepdim=True)
@@ -44,6 +46,7 @@ def load_wav(wav, target_sr):
         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech
 
+
 def speed_change(waveform, sample_rate, speed_factor: str):
     effects = [
         ["tempo", speed_factor],  # speed_factor

+ 1 - 0
cosyvoice/utils/frontend_utils.py

@@ -15,6 +15,7 @@
 import re
 chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
 
+
 # whether contain chinese character
 def contains_chinese(text):
     return bool(chinese_char_pattern.search(text))

+ 1 - 2
cosyvoice/utils/scheduler.py

@@ -567,8 +567,7 @@ class NoamAnnealing(_LRScheduler):
                  min_lr=0.0,
                  last_epoch=-1):
         self._normalize = d_model**(-0.5)
-        assert not (warmup_steps is not None
-                    and warmup_ratio is not None), \
+        assert not (warmup_steps is not None and warmup_ratio is not None), \
             "Either use particular number of step or ratio"
         assert warmup_ratio is None or max_steps is not None, \
             "If there is a ratio, there should be a total steps"

+ 2 - 2
cosyvoice/utils/train_utils.py

@@ -69,7 +69,6 @@ def init_dataset_and_dataloader(args, configs):
     return train_dataset, cv_dataset, train_data_loader, cv_data_loader
 
 
-
 def check_modify_and_save_config(args, configs):
     if args.train_engine == "torch_ddp":
         configs['train_conf']["dtype"] = 'fp32'
@@ -84,7 +83,8 @@ def check_modify_and_save_config(args, configs):
             configs['train_conf']["dtype"] = "fp32"
         assert ds_configs["train_micro_batch_size_per_gpu"] == 1
         # if use deepspeed, override ddp config
-        configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
+        configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
+                                                     configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
         configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
         configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
         configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]

+ 2 - 0
examples/libritts/cosyvoice/local/prepare_data.py

@@ -7,6 +7,7 @@ from tqdm import tqdm
 
 logger = logging.getLogger()
 
+
 def main():
     wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
 
@@ -41,6 +42,7 @@ def main():
             f.write('{} {}\n'.format(k, ' '.join(v)))
     return
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--src_dir',

+ 1 - 1
examples/libritts/cosyvoice/run.sh

@@ -83,7 +83,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   fi
   cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
   cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
-  for model in llm; do
+  for model in llm flow; do
     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
         --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
       cosyvoice/bin/train.py \

+ 2 - 0
examples/magicdata-read/cosyvoice/local/prepare_data.py

@@ -6,6 +6,7 @@ from tqdm import tqdm
 
 logger = logging.getLogger()
 
+
 def main():
     utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
     with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f:
@@ -40,6 +41,7 @@ def main():
             f.write('{} {}\n'.format(k, ' '.join(v)))
     return
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--src_dir',

+ 1 - 1
examples/magicdata-read/cosyvoice/run.sh

@@ -83,7 +83,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   fi
   cp data/train/parquet/data.list data/train.data.list
   cp data/dev/parquet/data.list data/dev.data.list
-  for model in llm; do
+  for model in llm flow; do
     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
         --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
       cosyvoice/bin/train.py \

+ 4 - 2
runtime/python/fastapi/client.py

@@ -38,7 +38,7 @@ def main():
         payload = {
             'tts_text': args.tts_text,
         }
-        files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
+        files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
         response = requests.request("GET", url, data=payload, files=files, stream=True)
     else:
         payload = {
@@ -55,6 +55,7 @@ def main():
     torchaudio.save(args.tts_wav, tts_speech, target_sr)
     logging.info('get response')
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--host',
@@ -81,7 +82,8 @@ if __name__ == "__main__":
                         default='../../../zero_shot_prompt.wav')
     parser.add_argument('--instruct_text',
                         type=str,
-                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
+                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
+                                 Fights with fervor for justice, but struggles with impulsiveness.')
     parser.add_argument('--tts_wav',
                         type=str,
                         default='demo.wav')

+ 11 - 5
runtime/python/fastapi/server.py

@@ -13,9 +13,6 @@
 # limitations under the License.
 import os
 import sys
-ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/../../..'.format(ROOT_DIR))
-sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
 import argparse
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
@@ -24,6 +21,9 @@ from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 import uvicorn
 import numpy as np
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../../..'.format(ROOT_DIR))
+sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice
 from cosyvoice.utils.file_utils import load_wav
 
@@ -36,34 +36,40 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"])
 
+
 def generate_data(model_output):
     for i in model_output:
         tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
         yield tts_audio
 
+
 @app.get("/inference_sft")
 async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
     model_output = cosyvoice.inference_sft(tts_text, spk_id)
     return StreamingResponse(generate_data(model_output))
 
+
 @app.get("/inference_zero_shot")
 async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
     prompt_speech_16k = load_wav(prompt_wav.file, 16000)
     model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
     return StreamingResponse(generate_data(model_output))
 
+
 @app.get("/inference_cross_lingual")
 async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
     prompt_speech_16k = load_wav(prompt_wav.file, 16000)
     model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
     return StreamingResponse(generate_data(model_output))
 
+
 @app.get("/inference_instruct")
 async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
     model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
     return StreamingResponse(generate_data(model_output))
 
-if __name__=='__main__':
+
+if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--port',
                         type=int,
@@ -74,4 +80,4 @@ if __name__=='__main__':
                         help='local path or modelscope repo id')
     args = parser.parse_args()
     cosyvoice = CosyVoice(args.model_dir)
-    uvicorn.run(app, host="127.0.0.1", port=args.port)
+    uvicorn.run(app, host="127.0.0.1", port=args.port)

+ 2 - 1
runtime/python/grpc/client.py

@@ -96,7 +96,8 @@ if __name__ == "__main__":
                         default='../../../zero_shot_prompt.wav')
     parser.add_argument('--instruct_text',
                         type=str,
-                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
+                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
+                                 Fights with fervor for justice, but struggles with impulsiveness.')
     parser.add_argument('--tts_wav',
                         type=str,
                         default='demo.wav')

+ 11 - 5
runtime/python/grpc/server.py

@@ -13,9 +13,6 @@
 # limitations under the License.
 import os
 import sys
-ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/../../..'.format(ROOT_DIR))
-sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from concurrent import futures
 import argparse
 import cosyvoice_pb2
@@ -25,11 +22,15 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import grpc
 import torch
 import numpy as np
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../../..'.format(ROOT_DIR))
+sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice
 
 logging.basicConfig(level=logging.DEBUG,
                     format='%(asctime)s %(levelname)s %(message)s')
 
+
 class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
     def __init__(self, args):
         self.cosyvoice = CosyVoice(args.model_dir)
@@ -43,7 +44,9 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
             logging.info('get zero_shot inference request')
             prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
             prompt_speech_16k = prompt_speech_16k.float() / (2**15)
-            model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, request.zero_shot_request.prompt_text, prompt_speech_16k)
+            model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,
+                                                              request.zero_shot_request.prompt_text,
+                                                              prompt_speech_16k)
         elif request.HasField('cross_lingual_request'):
             logging.info('get cross_lingual inference request')
             prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
@@ -51,7 +54,9 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
             model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
         else:
             logging.info('get instruct inference request')
-            model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
+            model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
+                                                             request.instruct_request.spk_id,
+                                                             request.instruct_request.instruct_text)
 
         logging.info('send inference response')
         for i in model_output:
@@ -59,6 +64,7 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
             response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
             yield response
 
+
 def main():
     grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
     cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)

+ 1 - 0
tools/extract_embedding.py

@@ -59,6 +59,7 @@ def main(args):
     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--dir',

+ 1 - 0
tools/make_parquet_list.py

@@ -53,6 +53,7 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
         json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
     logging.info('spend time {}'.format(time.time() - start_time))
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--num_utts_per_parquet',

+ 37 - 26
webui.py

@@ -13,9 +13,6 @@
 # limitations under the License.
 import os
 import sys
-ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
-
 import argparse
 import gradio as gr
 import numpy as np
@@ -23,9 +20,19 @@ import torch
 import torchaudio
 import random
 import librosa
-
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice
-from cosyvoice.utils.file_utils import load_wav, speed_change, logging
+from cosyvoice.utils.file_utils import load_wav, logging
+
+inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
+instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
+                 '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
+                 '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
+                 '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
+stream_mode_list = [('否', False), ('是', True)]
+max_val = 0.8
+
 
 def generate_seed():
     seed = random.randint(1, 100000000)
@@ -34,13 +41,14 @@ def generate_seed():
         "value": seed
     }
 
+
 def set_all_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
-max_val = 0.8
+
 def postprocess(speech, top_db=60, hop_length=220, win_length=440):
     speech, _ = librosa.effects.trim(
         speech, top_db=top_db,
@@ -52,16 +60,13 @@ def postprocess(speech, top_db=60, hop_length=220, win_length=440):
     speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
     return speech
 
-inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
-instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
-                 '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
-                 '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
-                 '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
-stream_mode_list = [('否', False), ('是', True)]
+
 def change_instruction(mode_checkbox_group):
     return instruct_dict[mode_checkbox_group]
 
-def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor):
+
+def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
+                   seed, stream, speed_factor):
     if prompt_wav_upload is not None:
         prompt_wav = prompt_wav_upload
     elif prompt_wav_record is not None:
@@ -72,31 +77,31 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group in ['自然语言控制']:
         if cosyvoice.frontend.instruct is False:
             gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
-            return (target_sr, default_data)
+            yield (target_sr, default_data)
         if instruct_text == '':
             gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
-            return (target_sr, default_data)
+            yield (target_sr, default_data)
         if prompt_wav is not None or prompt_text != '':
             gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
     # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
     if mode_checkbox_group in ['跨语种复刻']:
         if cosyvoice.frontend.instruct is True:
             gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
-            return (target_sr, default_data)
+            yield (target_sr, default_data)
         if instruct_text != '':
             gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
         if prompt_wav is None:
             gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
-            return (target_sr, default_data)
+            yield (target_sr, default_data)
         gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
     # if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
     if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
         if prompt_wav is None:
             gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
-            return (target_sr, default_data)
+            yield (target_sr, default_data)
         if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
             gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
-            return (target_sr, default_data)
+            yield (target_sr, default_data)
     # sft mode only use sft_dropdown
     if mode_checkbox_group in ['预训练音色']:
         if instruct_text != '' or prompt_wav is not None or prompt_text != '':
@@ -105,7 +110,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group in ['3s极速复刻']:
         if prompt_text == '':
             gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
-            return (target_sr, default_data)
+            yield (target_sr, default_data)
         if instruct_text != '':
             gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
 
@@ -113,28 +118,32 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
         logging.info('get sft inference request')
         set_all_random_seed(seed)
         for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
-            yield (target_sr,  i['tts_speech'].numpy().flatten())
+            yield (target_sr, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '3s极速复刻':
         logging.info('get zero_shot inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
         for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
-            yield (target_sr,  i['tts_speech'].numpy().flatten())
+            yield (target_sr, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '跨语种复刻':
         logging.info('get cross_lingual inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
         for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
-            yield (target_sr,  i['tts_speech'].numpy().flatten())
+            yield (target_sr, i['tts_speech'].numpy().flatten())
     else:
         logging.info('get instruct inference request')
         set_all_random_seed(seed)
         for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
-            yield (target_sr,  i['tts_speech'].numpy().flatten())
+            yield (target_sr, i['tts_speech'].numpy().flatten())
+
 
 def main():
     with gr.Blocks() as demo:
-        gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) 预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
+        gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
+                    预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
+                    [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
+                    [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
         gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
 
         tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
@@ -160,12 +169,14 @@ def main():
 
         seed_button.click(generate_seed, inputs=[], outputs=seed)
         generate_button.click(generate_audio,
-                              inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor],
+                              inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
+                                      seed, stream, speed_factor],
                               outputs=[audio_output])
         mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
     demo.queue(max_size=4, default_concurrency_limit=2)
     demo.launch(server_name='0.0.0.0', server_port=args.port)
 
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--port',