Forráskód Böngészése

Merge pull request #810 from FunAudioLLM/dev/lyuxiang.lx

add some instruction and assert
Xiang Lyu 11 hónapja
szülő
commit
88f6a8e2fa
6 módosított fájl, 52 hozzáadás és 63 törlés
  1. 3 6
      README.md
  2. 22 20
      cosyvoice/cli/cosyvoice.py
  3. 11 35
      cosyvoice/cli/frontend.py
  4. 2 0
      cosyvoice/cli/model.py
  5. 12 0
      cosyvoice/utils/class_utils.py
  6. 2 2
      webui.py

+ 3 - 6
README.md

@@ -121,13 +121,10 @@ We strongly recommend using `CosyVoice2-0.5B` for better performance.
 For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
 For sft inference, please use `CosyVoice-300M-SFT` model.
 For instruct inference, please use `CosyVoice-300M-Instruct` model.
-First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
-
-``` sh
-export PYTHONPATH=third_party/Matcha-TTS
-```
 
 ``` python
+import sys
+sys.path.append('third_party/Matcha-TTS')
 from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav
 import torchaudio
@@ -161,7 +158,7 @@ print(cosyvoice.list_available_spks())
 for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
     torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 
-cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz') # or change to pretrained_models/CosyVoice-300M for 50Hz inference
+cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M') # or change to pretrained_models/CosyVoice-300M-25Hz for 25Hz inference
 # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
 for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):

+ 22 - 20
cosyvoice/cli/cosyvoice.py

@@ -20,23 +20,24 @@ import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type
 
 
 class CosyVoice:
 
     def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
-        instruct = True if '-Instruct' in model_dir else False
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f)
+        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
@@ -85,8 +86,6 @@ class CosyVoice:
                 start_time = time.time()
 
     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        if self.frontend.instruct is True and isinstance(self.model, CosyVoiceModel):
-            raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
@@ -98,8 +97,8 @@ class CosyVoice:
                 start_time = time.time()
 
     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoiceModel)
-        if self.frontend.instruct is False:
+        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        if self.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
@@ -112,18 +111,6 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 
-    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoice2Model)
-        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
-            start_time = time.time()
-            logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
-                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
-                yield model_output
-                start_time = time.time()
-
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
         model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
@@ -137,18 +124,18 @@ class CosyVoice:
 class CosyVoice2(CosyVoice):
 
     def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
-        instruct = True if '-Instruct' in model_dir else False
+        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
                                           '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                           '{}/spk2info.pt'.format(model_dir),
-                                          instruct,
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and load_jit is True:
@@ -168,3 +155,18 @@ class CosyVoice2(CosyVoice):
         if load_trt:
             self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
         del configs
+
+    def inference_instruct(self, *args, **kwargs):
+        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
+
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

+ 11 - 35
cosyvoice/cli/frontend.py

@@ -42,7 +42,6 @@ class CosyVoiceFrontEnd:
                  campplus_model: str,
                  speech_tokenizer_model: str,
                  spk2info: str = '',
-                 instruct: bool = False,
                  allowed_special: str = 'all'):
         self.tokenizer = get_tokenizer()
         self.feat_extractor = feat_extractor
@@ -58,9 +57,7 @@ class CosyVoiceFrontEnd:
             self.spk2info = torch.load(spk2info, map_location=self.device)
         else:
             self.spk2info = {}
-        self.instruct = instruct
         self.allowed_special = allowed_special
-        self.inflect_parser = inflect.engine()
         self.use_ttsfrd = use_ttsfrd
         if self.use_ttsfrd:
             self.frd = ttsfrd.TtsFrontendEngine()
@@ -71,6 +68,7 @@ class CosyVoiceFrontEnd:
         else:
             self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
             self.en_tn_model = EnNormalizer()
+            self.inflect_parser = inflect.engine()
 
     def _extract_text_token(self, text):
         text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
@@ -111,15 +109,11 @@ class CosyVoiceFrontEnd:
         if text_frontend is False:
             return [text] if split is True else text
         text = text.strip()
-        # When generating text that contains only punctuation marks or whitespace characters
-        # - Returning empty texts ensures consistent processing logic.
-        if is_only_punctuation(text):
-            return []
-        if contains_chinese(text):
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
-            else:
+        if self.use_ttsfrd:
+            texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+            text = ''.join(texts)
+        else:
+            if contains_chinese(text):
                 text = self.zh_tn_model.normalize(text)
                 text = text.replace("\n", "")
                 text = replace_blank(text)
@@ -130,18 +124,13 @@ class CosyVoiceFrontEnd:
                 text = re.sub(r'[,,、]+$', '。', text)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
-        else:
-            if self.use_ttsfrd:
-                texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
-                text = ''.join(texts)
             else:
                 text = self.en_tn_model.normalize(text)
                 text = spell_out_number(text, self.inflect_parser)
                 texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                              token_min_n=60, merge_len=20, comma_split=False))
-        if split is False:
-            return text
-        return texts
+        texts = [i for i in texts if not is_only_punctuation(i)]
+        return texts if split is True else text
 
     def frontend_sft(self, tts_text, spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
@@ -188,22 +177,9 @@ class CosyVoiceFrontEnd:
         return model_input
 
     def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
-        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
-        prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        if resample_rate == 24000:
-            # cosyvoice2, force speech_feat % speech_token = 2
-            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
-            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
-            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
-        embedding = self._extract_spk_embedding(prompt_speech_16k)
-        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
-                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                       'llm_embedding': embedding, 'flow_embedding': embedding}
+        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
+        del model_input['llm_prompt_speech_token']
+        del model_input['llm_prompt_speech_token_len']
         return model_input
 
     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):

+ 2 - 0
cosyvoice/cli/model.py

@@ -316,6 +316,8 @@ class CosyVoice2Model:
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
             self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        if self.flow.decoder.estimator_engine is None:
+            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
         self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
         self.flow.decoder.fp16 = True
 

+ 12 - 0
cosyvoice/utils/class_utils.py

@@ -32,6 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM
+from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
+from cosyvoice.hifigan.generator import HiFTGenerator
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 
 
 COSYVOICE_ACTIVATION_CLASSES = {
@@ -68,3 +72,11 @@ COSYVOICE_ATTENTION_CLASSES = {
     "selfattn": MultiHeadedAttention,
     "rel_selfattn": RelPositionMultiHeadedAttention,
 }
+
+
+def get_model_type(configs):
+    if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+        return CosyVoiceModel
+    if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
+        return CosyVoice2Model
+    raise TypeError('No valid model type found!')

+ 2 - 2
webui.py

@@ -69,7 +69,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
         prompt_wav = None
     # if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
     if mode_checkbox_group in ['自然语言控制']:
-        if cosyvoice.frontend.instruct is False:
+        if cosyvoice.instruct is False:
             gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
             yield (cosyvoice.sample_rate, default_data)
         if instruct_text == '':
@@ -79,7 +79,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
             gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
     # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
     if mode_checkbox_group in ['跨语种复刻']:
-        if cosyvoice.frontend.instruct is True:
+        if cosyvoice.instruct is True:
             gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
             yield (cosyvoice.sample_rate, default_data)
         if instruct_text != '':