Browse Source

fix export_jit.py

lyuxiang.lx 1 month ago
parent
commit
2db78e7058
3 changed files with 4 additions and 5 deletions
  1. 2 3
      cosyvoice/bin/export_jit.py
  2. 1 1
      cosyvoice/cli/cosyvoice.py
  3. 1 1
      cosyvoice/cli/model.py

+ 2 - 3
cosyvoice/bin/export_jit.py

@@ -24,7 +24,6 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import AutoModel
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
 
 
@@ -60,7 +59,7 @@ def main():
 
     model = AutoModel(model_dir=args.model_dir)
 
-    if isinstance(model.model, CosyVoiceModel):
+    if model.__class__.__name__ == 'CosyVoice':
         # 1. export llm text_encoder
         llm_text_encoder = model.model.llm.text_encoder
         script = get_optimized_script(llm_text_encoder)
@@ -84,7 +83,7 @@ def main():
         script = get_optimized_script(flow_encoder.half())
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
         logging.info('successfully export flow_encoder')
-    elif isinstance(model.model, CosyVoice2Model):
+    elif model.__class__.__name__ == 'CosyVoice2':
         # 1. export flow encoder
         flow_encoder = model.model.flow.encoder
         script = get_optimized_script(flow_encoder)

+ 1 - 1
cosyvoice/cli/cosyvoice.py

@@ -114,7 +114,7 @@ class CosyVoice:
                 start_time = time.time()
 
     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+        assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)

+ 1 - 1
cosyvoice/cli/model.py

@@ -100,7 +100,7 @@ class CosyVoiceModel:
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
-                assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
+                assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
                 for i in self.llm.inference_bistream(text=text,
                                                      prompt_text=prompt_text.to(self.device),
                                                      prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),