lyuxiang.lx 8 months ago
parent
commit
e3c2400abb

+ 41 - 1
cosyvoice/cli/cosyvoice.py

@@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.class_utils import get_model_type
 
@@ -192,3 +192,43 @@ class CosyVoice2(CosyVoice):
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
+
+
+class CosyVoice3(CosyVoice):
+
+    def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
+        self.instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        self.fp16 = fp16
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v3.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/bigvgan.pt'.format(model_dir))
+        if load_vllm:
+            self.model.load_vllm('{}/vllm'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
+                                self.fp16)
+        del configs

+ 67 - 0
cosyvoice/cli/model.py

@@ -384,3 +384,70 @@ class CosyVoice2Model(CosyVoiceModel):
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()
+
+
+class CosyVoice3Model(CosyVoice2Model):
+
+    def __init__(self,
+                 llm: torch.nn.Module,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module,
+                 fp16: bool = False):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.llm = llm
+        self.flow = flow
+        self.hift = hift
+        self.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
+        # NOTE must matching training static_chunk_size
+        self.token_hop_len = 25
+        # hift cache
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        # rtf and decoding related
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.hift_cache_dict = {}
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             streaming=stream,
+                                             finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech

+ 45 - 0
cosyvoice/hifigan/f0_predictor.py

@@ -17,6 +17,7 @@ try:
     from torch.nn.utils.parametrizations import weight_norm
 except ImportError:
     from torch.nn.utils import weight_norm
+from cosyvoice.transformer.convolution import CausalConv1d
 
 
 class ConvRNNF0Predictor(nn.Module):
@@ -56,3 +57,47 @@ class ConvRNNF0Predictor(nn.Module):
         x = self.condnet(x)
         x = x.transpose(1, 2)
         return torch.abs(self.classifier(x).squeeze(-1))
+
+
+class CausalConvRNNF0Predictor(nn.Module):
+    def __init__(self,
+                 num_class: int = 1,
+                 in_channels: int = 80,
+                 cond_channels: int = 512
+                 ):
+        super().__init__()
+
+        self.num_class = num_class
+        self.condnet = nn.Sequential(
+            weight_norm(
+                CausalConv1d(in_channels, cond_channels, kernel_size=4, causal_type='right')
+            ),
+            nn.ELU(),
+            weight_norm(
+                CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
+            ),
+            nn.ELU(),
+            weight_norm(
+                CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
+            ),
+            nn.ELU(),
+            weight_norm(
+                CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
+            ),
+            nn.ELU(),
+            weight_norm(
+                CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
+            ),
+            nn.ELU(),
+        )
+        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
+
+    def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Tensor:
+        if finalize is True:
+            x = self.condnet[0](x)
+        else:
+            x = self.condnet[0](x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:])
+        for i in range(1, len(self.condnet)):
+            x = self.condnet[i](x)
+        x = x.transpose(1, 2)
+        return torch.abs(self.classifier(x).squeeze(-1))

+ 234 - 71
cosyvoice/hifigan/generator.py

@@ -28,7 +28,7 @@ try:
 except ImportError:
     from torch.nn.utils import weight_norm
 from torch.distributions.uniform import Uniform
-
+from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
 from cosyvoice.transformer.activation import Snake
 from cosyvoice.utils.common import get_padding
 from cosyvoice.utils.common import init_weights
@@ -50,8 +50,10 @@ class ResBlock(torch.nn.Module):
         channels: int = 512,
         kernel_size: int = 3,
         dilations: List[int] = [1, 3, 5],
+        causal: bool = False,
     ):
         super(ResBlock, self).__init__()
+        self.causal = causal
         self.convs1 = nn.ModuleList()
         self.convs2 = nn.ModuleList()
 
@@ -64,7 +66,14 @@ class ResBlock(torch.nn.Module):
                         kernel_size,
                         1,
                         dilation=dilation,
-                        padding=get_padding(kernel_size, dilation)
+                        padding=get_padding(kernel_size, dilation)) if causal is False else
+                    CausalConv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation,
+                        causal_type='left'
                     )
                 )
             )
@@ -76,7 +85,14 @@ class ResBlock(torch.nn.Module):
                         kernel_size,
                         1,
                         dilation=1,
-                        padding=get_padding(kernel_size, 1)
+                        padding=get_padding(kernel_size, 1)) if causal is False else
+                    CausalConv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        causal_type='left'
                     )
                 )
             )
@@ -171,58 +187,6 @@ class SineGen(torch.nn.Module):
         return sine_waves, uv, noise
 
 
-class SourceModuleHnNSF(torch.nn.Module):
-    """ SourceModule for hn-nsf
-    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
-                 add_noise_std=0.003, voiced_threshod=0)
-    sampling_rate: sampling_rate in Hz
-    harmonic_num: number of harmonic above F0 (default: 0)
-    sine_amp: amplitude of sine source signal (default: 0.1)
-    add_noise_std: std of additive Gaussian noise (default: 0.003)
-        note that amplitude of noise in unvoiced is decided
-        by sine_amp
-    voiced_threshold: threhold to set U/V given F0 (default: 0)
-    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
-    F0_sampled (batchsize, length, 1)
-    Sine_source (batchsize, length, 1)
-    noise_source (batchsize, length 1)
-    uv (batchsize, length, 1)
-    """
-
-    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
-                 add_noise_std=0.003, voiced_threshod=0):
-        super(SourceModuleHnNSF, self).__init__()
-
-        self.sine_amp = sine_amp
-        self.noise_std = add_noise_std
-
-        # to produce sine waveforms
-        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
-                                 sine_amp, add_noise_std, voiced_threshod)
-
-        # to merge source harmonics into a single excitation
-        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
-        self.l_tanh = torch.nn.Tanh()
-
-    def forward(self, x):
-        """
-        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
-        F0_sampled (batchsize, length, 1)
-        Sine_source (batchsize, length, 1)
-        noise_source (batchsize, length 1)
-        """
-        # source for harmonic branch
-        with torch.no_grad():
-            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
-            sine_wavs = sine_wavs.transpose(1, 2)
-            uv = uv.transpose(1, 2)
-        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
-
-        # source for noise branch, in the same shape as uv
-        noise = torch.randn_like(uv) * self.sine_amp / 3
-        return sine_merge, noise, uv
-
-
 class SineGen2(torch.nn.Module):
     """ Definition of sine generator
     SineGen(samp_rate, harmonic_num = 0,
@@ -242,7 +206,8 @@ class SineGen2(torch.nn.Module):
     def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                  sine_amp=0.1, noise_std=0.003,
                  voiced_threshold=0,
-                 flag_for_pulse=False):
+                 flag_for_pulse=False,
+                 causal=False):
         super(SineGen2, self).__init__()
         self.sine_amp = sine_amp
         self.noise_std = noise_std
@@ -252,6 +217,11 @@ class SineGen2(torch.nn.Module):
         self.voiced_threshold = voiced_threshold
         self.flag_for_pulse = flag_for_pulse
         self.upsample_scale = upsample_scale
+        self.causal = causal
+        if causal is True:
+            self.rand_ini = torch.rand(1, 9)
+            self.rand_ini[:, 0] = 0
+            self.sine_waves = torch.rand(1, 60 * 16000, 9)
 
     def _f02uv(self, f0):
         # generate uv signal
@@ -267,9 +237,12 @@ class SineGen2(torch.nn.Module):
         rad_values = (f0_values / self.sampling_rate) % 1
 
         # initial phase noise (no noise for fundamental component)
-        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
-        rand_ini[:, 0] = 0
-        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+        if self.training is False and self.causal is True:
+            rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
+        else:
+            rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
+            rand_ini[:, 0] = 0
+            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
 
         # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
         if not self.flag_for_pulse:
@@ -279,7 +252,7 @@ class SineGen2(torch.nn.Module):
 
             phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
             phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
-                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+                                                    scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
             sines = torch.sin(phase)
         else:
             # If necessary, make sure that the first time step of every
@@ -331,7 +304,10 @@ class SineGen2(torch.nn.Module):
         #        std = self.sine_amp/3 -> max value ~ self.sine_amp
         # .       for voiced regions is self.noise_std
         noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
-        noise = noise_amp * torch.randn_like(sine_waves)
+        if self.training is False and self.causal is True:
+            noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
+        else:
+            noise = noise_amp * torch.randn_like(sine_waves)
 
         # first: set the unvoiced part to 0 by uv
         # then: additive noise
@@ -339,7 +315,7 @@ class SineGen2(torch.nn.Module):
         return sine_waves, uv, noise
 
 
-class SourceModuleHnNSF2(torch.nn.Module):
+class SourceModuleHnNSF(torch.nn.Module):
     """ SourceModule for hn-nsf
     SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                  add_noise_std=0.003, voiced_threshod=0)
@@ -358,19 +334,26 @@ class SourceModuleHnNSF2(torch.nn.Module):
     """
 
     def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
-                 add_noise_std=0.003, voiced_threshod=0):
-        super(SourceModuleHnNSF2, self).__init__()
+                 add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
+        super(SourceModuleHnNSF, self).__init__()
 
         self.sine_amp = sine_amp
         self.noise_std = add_noise_std
 
         # to produce sine waveforms
-        self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
-                                  sine_amp, add_noise_std, voiced_threshod)
+        if sinegen_type == '1':
+            self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+                                    sine_amp, add_noise_std, voiced_threshod)
+        else:
+            self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
+                                    sine_amp, add_noise_std, voiced_threshod, causal=causal)
 
         # to merge source harmonics into a single excitation
         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
         self.l_tanh = torch.nn.Tanh()
+        self.causal = causal
+        if causal is True:
+            self.uv = torch.rand(1, 60 * 24000, 1)
 
     def forward(self, x):
         """
@@ -385,7 +368,10 @@ class SourceModuleHnNSF2(torch.nn.Module):
         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
 
         # source for noise branch, in the same shape as uv
-        noise = torch.randn_like(uv) * self.sine_amp / 3
+        if self.training is False and self.causal is True:
+            noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
+        else:
+            noise = torch.randn_like(uv) * self.sine_amp / 3
         return sine_merge, noise, uv
 
 
@@ -425,15 +411,16 @@ class HiFTGenerator(nn.Module):
 
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
-        # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
-        this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
-        self.m_source = this_SourceModuleHnNSF(
+        # NOTE in CosyVoice2, we use the original SineGen implementation
+        self.m_source = SourceModuleHnNSF(
             sampling_rate=sampling_rate,
             upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
             harmonic_num=nb_harmonics,
             sine_amp=nsf_alpha,
             add_noise_std=nsf_sigma,
-            voiced_threshod=nsf_voiced_threshold)
+            voiced_threshod=nsf_voiced_threshold,
+            sinegen_type='1' if self.sampling_rate == 22050 else '2',
+            causal=False)
         self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
 
         self.conv_pre = weight_norm(
@@ -580,3 +567,179 @@ class HiFTGenerator(nn.Module):
             s[:, :, :cache_source.shape[2]] = cache_source
         generated_speech = self.decode(x=speech_feat, s=s)
         return generated_speech, s
+
+
+class CausalHiFTGenerator(HiFTGenerator):
+    """
+    HiFTNet Generator: Neural Source Filter + ISTFTNet
+    https://arxiv.org/abs/2309.09493
+    """
+    def __init__(
+            self,
+            in_channels: int = 80,
+            base_channels: int = 512,
+            nb_harmonics: int = 8,
+            sampling_rate: int = 22050,
+            nsf_alpha: float = 0.1,
+            nsf_sigma: float = 0.003,
+            nsf_voiced_threshold: float = 10,
+            upsample_rates: List[int] = [8, 8],
+            upsample_kernel_sizes: List[int] = [16, 16],
+            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+            resblock_kernel_sizes: List[int] = [3, 7, 11],
+            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            source_resblock_kernel_sizes: List[int] = [7, 11],
+            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
+            lrelu_slope: float = 0.1,
+            audio_limit: float = 0.99,
+            conv_pre_look_right: int = 4,
+            f0_predictor: torch.nn.Module = None,
+    ):
+        torch.nn.Module.__init__(self)
+
+        self.out_channels = 1
+        self.nb_harmonics = nb_harmonics
+        self.sampling_rate = sampling_rate
+        self.istft_params = istft_params
+        self.lrelu_slope = lrelu_slope
+        self.audio_limit = audio_limit
+
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.m_source = SourceModuleHnNSF(
+            sampling_rate=sampling_rate,
+            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
+            harmonic_num=nb_harmonics,
+            sine_amp=nsf_alpha,
+            add_noise_std=nsf_sigma,
+            voiced_threshod=nsf_voiced_threshold,
+            sinegen_type='1' if self.sampling_rate == 22050 else '2',
+            causal=True)
+        self.upsample_rates = upsample_rates
+        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
+
+        self.conv_pre = weight_norm(
+            CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
+        )
+
+        # Up
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    CausalConv1dUpsample(
+                        base_channels // (2**i),
+                        base_channels // (2**(i + 1)),
+                        k,
+                        u,
+                    )
+                )
+            )
+
+        # Down
+        self.source_downs = nn.ModuleList()
+        self.source_resblocks = nn.ModuleList()
+        downsample_rates = [1] + upsample_rates[::-1][:-1]
+        downsample_cum_rates = np.cumprod(downsample_rates)
+        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
+            if u == 1:
+                self.source_downs.append(
+                    CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
+                )
+            else:
+                self.source_downs.append(
+                    CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
+                )
+
+            self.source_resblocks.append(
+                ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = base_channels // (2**(i + 1))
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(ResBlock(ch, k, d, causal=True))
+
+        self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+        self.reflection_pad = nn.ReflectionPad1d((1, 0))
+        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
+        self.conv_pre_look_right = conv_pre_look_right
+        self.f0_predictor = f0_predictor
+
+    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
+        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
+        if finalize is True:
+            x = self.conv_pre(x)
+        else:
+            x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
+            s_stft_real, s_stft_imag = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)], s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
+        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, self.lrelu_slope)
+            x = self.ups[i](x)
+
+            if i == self.num_upsamples - 1:
+                x = self.reflection_pad(x)
+
+            # fusion
+            si = self.source_downs[i](s_stft)
+            si = self.source_resblocks[i](si)
+            x = x + si
+
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
+        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy
+
+        x = self._istft(magnitude, phase)
+        if finalize is False:
+            x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
+        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
+        return x
+
+    @torch.inference_mode()
+    def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
+        # mel->f0
+        self.f0_predictor.to('cpu')
+        f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        if finalize is True:
+            generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
+        else:
+            generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
+        return generated_speech, s
+
+
+if __name__ == '__main__':
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    from hyperpyyaml import load_hyperpyyaml
+    with open('./pretrained_models/CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
+        configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
+    model = configs['hift']
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model.to(device)
+    model.eval()
+    max_len, chunk_size, context_size = 300, 30, 8
+    mel = torch.rand(1, 80, max_len)
+    pred_gt, _ = model.inference(mel)
+    for i in range(0, max_len, chunk_size):
+        finalize = True if i + chunk_size + context_size >= max_len else False
+        pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
+        pred_chunk = pred_chunk[:, i * 480:]
+        print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())

+ 113 - 0
cosyvoice/transformer/convolution.py

@@ -19,6 +19,7 @@ from typing import Tuple
 
 import torch
 from torch import nn
+import torch.nn.functional as F
 
 
 class ConvolutionModule(nn.Module):
@@ -143,3 +144,115 @@ class ConvolutionModule(nn.Module):
             x.masked_fill_(~mask_pad, 0.0)
 
         return x.transpose(1, 2), new_cache
+
+
+# NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        causal_type: str = 'left',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride=1,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride == 1
+        self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2
+        assert causal_type in ['left', 'right']
+        self.causal_type = causal_type
+
+    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor]:
+        input_timestep = x.shape[2]
+        if cache.size(2) == 0:
+            cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
+        assert cache.size(2) == self.causal_padding
+        if self.causal_type == 'left':
+            x = torch.concat([cache, x], dim=2)
+        else:
+            x = torch.concat([x, cache], dim=2)
+        x = super(CausalConv1d, self).forward(x)
+        assert x.shape[2] == input_timestep
+        return x
+
+
+class CausalConv1dDownSample(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride != 1 and dilation == 1
+        assert kernel_size % stride == 0
+        self.causal_padding = stride - 1
+
+    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+        if cache.size(2) == 0:
+            x = F.pad(x, (self.causal_padding, 0), value=0.0)
+        else:
+            assert cache.size(2) == self.causal_padding
+            x = torch.concat([cache, x], dim=2)
+        x = super(CausalConv1dDownSample, self).forward(x)
+        return x
+
+
+class CausalConv1dUpsample(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
+                                           kernel_size, 1,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert dilation == 1
+        self.causal_padding = kernel_size - 1
+        self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
+
+    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = self.upsample(x)
+        input_timestep = x.shape[2]
+        if cache.size(2) == 0:
+            x = F.pad(x, (self.causal_padding, 0), value=0.0)
+        else:
+            assert cache.size(2) == self.causal_padding
+            x = torch.concat([cache, x], dim=2)
+        x = super(CausalConv1dUpsample, self).forward(x)
+        assert input_timestep == x.shape[2]
+        return x

+ 234 - 0
examples/libritts/cosyvoice3/conf/cosyvoice2.yaml

@@ -0,0 +1,234 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 24000
+llm_input_size: 896
+llm_output_size: 896
+spk_embed_dim: 192
+qwen_pretrain_path: ''
+token_frame_rate: 25
+token_mel_ratio: 2
+
+# stream related params
+chunk_size: 25 # streaming inference chunk size, in token
+num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
+
+# model params
+# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
+# for system/third_party class/function, we do not require this.
+llm: !new:cosyvoice.llm.llm.Qwen2LM
+    llm_input_size: !ref <llm_input_size>
+    llm_output_size: !ref <llm_output_size>
+    speech_token_size: 6561
+    length_normalized_loss: True
+    lsm_weight: 0
+    mix_ratio: [5, 15]
+    llm: !new:cosyvoice.llm.llm.Qwen2Encoder
+        pretrain_path: !ref <qwen_pretrain_path>
+    sampling: !name:cosyvoice.utils.common.ras_sampling
+        top_p: 0.8
+        top_k: 25
+        win_size: 10
+        tau_r: 0.1
+
+flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
+    input_size: 512
+    output_size: 80
+    spk_embed_dim: !ref <spk_embed_dim>
+    output_type: 'mel'
+    vocab_size: 6561
+    input_frame_rate: !ref <token_frame_rate>
+    only_mask_loss: True
+    token_mel_ratio: !ref <token_mel_ratio>
+    pre_lookahead_len: 3
+    encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
+        output_size: 512
+        attention_heads: 8
+        linear_units: 2048
+        num_blocks: 6
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.1
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        input_size: 512
+        use_cnn_module: False
+        macaron_style: False
+        static_chunk_size: !ref <chunk_size>
+    decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
+        in_channels: 240
+        n_spks: 1
+        spk_emb_dim: 80
+        cfm_params: !new:omegaconf.DictConfig
+            content:
+                sigma_min: 1e-06
+                solver: 'euler'
+                t_scheduler: 'cosine'
+                training_cfg_rate: 0.2
+                inference_cfg_rate: 0.7
+                reg_loss_type: 'l1'
+        estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
+            in_channels: 320
+            out_channels: 80
+            channels: [256]
+            dropout: 0.0
+            attention_head_dim: 64
+            n_blocks: 4
+            num_mid_blocks: 12
+            num_heads: 8
+            act_fn: 'gelu'
+            static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
+            num_decoding_left_chunks: !ref <num_decoding_left_chunks>
+
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+    in_channels: 80
+    base_channels: 512
+    nb_harmonics: 8
+    sampling_rate: !ref <sample_rate>
+    nsf_alpha: 0.1
+    nsf_sigma: 0.003
+    nsf_voiced_threshold: 10
+    upsample_rates: [8, 5, 3]
+    upsample_kernel_sizes: [16, 11, 7]
+    istft_params:
+        n_fft: 16
+        hop_len: 4
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    source_resblock_kernel_sizes: [7, 7, 11]
+    source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    lrelu_slope: 0.1
+    audio_limit: 0.99
+    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+        num_class: 1
+        in_channels: 80
+        cond_channels: 512
+
+# gan related module
+mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1920
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 480
+    win_size: 1920
+    fmin: 0
+    fmax: null
+    center: False
+hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
+    generator: !ref <hift>
+    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
+        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
+    mel_spec_transform: [
+        !ref <mel_spec_transform1>
+    ]
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
+    token_path: !ref <qwen_pretrain_path>
+    skip_special_tokens: True
+allowed_special: 'all'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+    get_tokenizer: !ref <get_tokenizer>
+    allowed_special: !ref <allowed_special>
+filter: !name:cosyvoice.dataset.processor.filter
+    max_length: 40960
+    min_length: 100
+    token_max_length: 200
+    token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+    resample_rate: !ref <sample_rate>
+truncate: !name:cosyvoice.dataset.processor.truncate
+    truncate_length: 24480 # must be a multiplier of hop_size
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1920
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 480
+    win_size: 1920
+    fmin: 0
+    fmax: 8000
+    center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+    feat_extractor: !ref <feat_extractor>
+    token_mel_ratio: 2
+compute_f0: !name:cosyvoice.dataset.processor.compute_f0
+    sample_rate: !ref <sample_rate>
+    hop_size: 480
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+    normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+    shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+    sort_size: 500  # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+    batch_type: 'dynamic'
+    max_frames_in_batch: 2000
+padding: !name:cosyvoice.dataset.processor.padding
+    use_spk_embedding: False # change to True during sft
+
+
+# dataset processor pipeline
+data_pipeline: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <compute_fbank>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+data_pipeline_gan: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <truncate>,
+    !ref <compute_fbank>,
+    !ref <compute_f0>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+
+# llm flow train conf
+train_conf:
+    optim: adam
+    optim_conf:
+        lr: 1e-5 # change to 1e-5 during sft
+    scheduler: constantlr # change to constantlr during sft
+    scheduler_conf:
+        warmup_steps: 2500
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 2
+    log_interval: 100
+    save_per_step: -1
+
+# gan train conf
+train_conf_gan:
+    optim: adam
+    optim_conf:
+        lr: 0.0002 # use small lr for gan training
+    scheduler: constantlr
+    optim_d: adam
+    optim_conf_d:
+        lr: 0.0002 # use small lr for gan training
+    scheduler_d: constantlr
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 1 # in gan training, accum_grad must be 1
+    log_interval: 100
+    save_per_step: -1

+ 42 - 0
examples/libritts/cosyvoice3/conf/ds_stage2.json

@@ -0,0 +1,42 @@
+{
+  "train_micro_batch_size_per_gpu": 1,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 100,
+  "gradient_clipping": 5,
+  "fp16": {
+    "enabled": false,
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 16,
+    "loss_scale_window": 256,
+    "hysteresis": 2,
+    "consecutive_hysteresis": false,
+    "min_loss_scale": 1
+  },
+  "bf16": {
+    "enabled": false
+  },
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_optimization": {
+    "stage": 2,
+    "offload_optimizer": {
+      "device": "none",
+      "pin_memory": true
+    },
+    "allgather_partitions": true,
+    "allgather_bucket_size": 5e8,
+    "overlap_comm": false,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 5e8,
+    "contiguous_gradients" : true
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+        "lr": 0.001,
+        "weight_decay": 0.0001,
+        "torch_adam": true,
+        "adam_w_mode": true
+    }
+  }
+}

+ 1 - 0
examples/libritts/cosyvoice3/cosyvoice

@@ -0,0 +1 @@
+../../../cosyvoice

+ 1 - 0
examples/libritts/cosyvoice3/local

@@ -0,0 +1 @@
+../cosyvoice/local

+ 1 - 0
examples/libritts/cosyvoice3/path.sh

@@ -0,0 +1 @@
+../cosyvoice/path.sh

+ 111 - 0
examples/libritts/cosyvoice3/run.sh

@@ -0,0 +1,111 @@
+#!/bin/bash
+# Copyright 2024 Alibaba Inc. All Rights Reserved.
+. ./path.sh || exit 1;
+
+stage=-1
+stop_stage=3
+
+data_url=www.openslr.org/resources/60
+data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
+pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+  echo "Data Download"
+  for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+    local/download_and_untar.sh ${data_dir} ${data_url} ${part}
+  done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+  echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    mkdir -p data/$x
+    python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
+  done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    tools/extract_embedding.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/campplus.onnx
+  done
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    tools/extract_speech_token.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
+  done
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    mkdir -p data/$x/parquet
+    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+      --num_processes 10 \
+      --src_dir data/$x \
+      --des_dir data/$x/parquet
+  done
+fi
+
+# train llm
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+job_id=1986
+dist_backend="nccl"
+num_workers=2
+prefetch=100
+train_engine=torch_ddp
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  echo "Run train. We only support llm traning for now"
+  if [ $train_engine == 'deepspeed' ]; then
+    echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
+  fi
+  cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
+  cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
+  # NOTE will update llm/hift training later
+  for model in llm flow hifigan; do
+    torchrun --nnodes=1 --nproc_per_node=$num_gpus \
+        --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
+      cosyvoice/bin/train.py \
+      --train_engine $train_engine \
+      --config conf/cosyvoice2.yaml \
+      --train_data data/train.data.list \
+      --cv_data data/dev.data.list \
+      --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
+      --model $model \
+      --checkpoint $pretrained_model_dir/$model.pt \
+      --model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
+      --tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
+      --ddp.dist_backend $dist_backend \
+      --num_workers ${num_workers} \
+      --prefetch ${prefetch} \
+      --pin_memory \
+      --use_amp \
+      --deepspeed_config ./conf/ds_stage2.json \
+      --deepspeed.save_states model+optimizer
+  done
+fi
+
+# average model
+average_num=5
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+  for model in llm flow hifigan; do
+    decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
+    echo "do model average and final checkpoint is $decode_checkpoint"
+    python cosyvoice/bin/average_model.py \
+      --dst_model $decode_checkpoint \
+      --src_path `pwd`/exp/cosyvoice/$model/$train_engine  \
+      --num ${average_num} \
+      --val_best
+  done
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+  echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
+  python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
+  python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
+fi

+ 1 - 0
examples/libritts/cosyvoice3/tools

@@ -0,0 +1 @@
+../../../tools