lyuxiang.lx 8 months ago
parent
commit
8c96081f94

+ 5 - 2
cosyvoice/bin/export_onnx.py

@@ -27,7 +27,7 @@ from tqdm import tqdm
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2, CosyVoice3
 from cosyvoice.utils.file_utils import logging
 
 
@@ -64,7 +64,10 @@ def main():
         try:
             model = CosyVoice2(args.model_dir)
         except Exception:
-            raise TypeError('no valid model_type!')
+            try:
+                model = CosyVoice3(args.model_dir)
+            except Exception:
+                raise TypeError('no valid model_type!')
 
     # 1. export flow decoder estimator
     estimator = model.model.flow.decoder.estimator

+ 1 - 1
cosyvoice/cli/cosyvoice.py

@@ -221,7 +221,7 @@ class CosyVoice3(CosyVoice):
         self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
-                        '{}/bigvgan.pt'.format(model_dir))
+                        '{}/hift.pt'.format(model_dir))
         if load_vllm:
             self.model.load_vllm('{}/vllm'.format(model_dir))
         if load_jit:

+ 1 - 1
cosyvoice/cli/model.py

@@ -447,7 +447,7 @@ class CosyVoice3Model(CosyVoice2Model):
             if speed != 1.0:
                 assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                 tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
-            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel)
             if self.hift_cache_dict[uuid] is not None:
                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech

+ 2 - 0
cosyvoice/flow/DiT/dit.py

@@ -115,6 +115,7 @@ class DiT(nn.Module):
         mu_dim=None,
         long_skip_connection=False,
         spk_dim=None,
+        out_channels=None,
         static_chunk_size=50,
         num_decoding_left_chunks=2
     ):
@@ -137,6 +138,7 @@ class DiT(nn.Module):
 
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
+        self.out_channels = out_channels
         self.static_chunk_size = static_chunk_size
         self.num_decoding_left_chunks = num_decoding_left_chunks
 

+ 4 - 2
cosyvoice/utils/class_utils.py

@@ -33,8 +33,8 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
 from cosyvoice.llm.llm import TransformerLM, Qwen2LM
-from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
-from cosyvoice.hifigan.generator import HiFTGenerator
+from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
+from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
 from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 
 
@@ -80,4 +80,6 @@ def get_model_type(configs):
         return CosyVoiceModel
     if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
         return CosyVoice2Model
+    if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
+        return CosyVoice2Model
     raise TypeError('No valid model type found!')