瀏覽代碼

add hifigan train code

lyuxiang.lx 1 年之前
父節點
當前提交
cb200b21c5

+ 1 - 1
cosyvoice/bin/train.py

@@ -87,7 +87,7 @@ def main():
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
 
-    override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
+    override_dict = {k: None for k in ['llm', 'flow', 'hifigan'] if k != args.model}
     with open(args.config, 'r') as f:
         configs = load_hyperpyyaml(f, overrides=override_dict)
     configs['train_conf'].update(vars(args))

+ 137 - 0
cosyvoice/bin/train_gan.py

@@ -0,0 +1,137 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import argparse
+import datetime
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from copy import deepcopy
+import torch
+import torch.distributed as dist
+import deepspeed
+
+from hyperpyyaml import load_hyperpyyaml
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from cosyvoice.utils.executor_gan import Executor
+from cosyvoice.utils.train_utils import (
+    init_distributed,
+    init_dataset_and_dataloader,
+    init_optimizer_and_scheduler_gan,
+    init_summarywriter, save_model,
+    wrap_cuda_model, check_modify_and_save_config)
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='training your network')
+    parser.add_argument('--train_engine',
+                        default='torch_ddp',
+                        choices=['torch_ddp', 'deepspeed'],
+                        help='Engine for paralleled training')
+    parser.add_argument('--model', required=True, help='model which will be trained')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument('--train_data', required=True, help='train data file')
+    parser.add_argument('--cv_data', required=True, help='cv data file')
+    parser.add_argument('--checkpoint', help='checkpoint model')
+    parser.add_argument('--model_dir', required=True, help='save model dir')
+    parser.add_argument('--tensorboard_dir',
+                        default='tensorboard',
+                        help='tensorboard log dir')
+    parser.add_argument('--ddp.dist_backend',
+                        dest='dist_backend',
+                        default='nccl',
+                        choices=['nccl', 'gloo'],
+                        help='distributed backend')
+    parser.add_argument('--num_workers',
+                        default=0,
+                        type=int,
+                        help='num of subprocess workers for reading')
+    parser.add_argument('--prefetch',
+                        default=100,
+                        type=int,
+                        help='prefetch number')
+    parser.add_argument('--pin_memory',
+                        action='store_true',
+                        default=False,
+                        help='Use pinned memory buffers used for reading')
+    parser.add_argument('--deepspeed.save_states',
+                        dest='save_states',
+                        default='model_only',
+                        choices=['model_only', 'model+optimizer'],
+                        help='save model/optimizer states')
+    parser.add_argument('--timeout',
+                        default=30,
+                        type=int,
+                        help='timeout (in seconds) of cosyvoice_join.')
+    parser = deepspeed.add_config_arguments(parser)
+    args = parser.parse_args()
+    return args
+
+
+@record
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+
+    override_dict = {k: None for k in ['llm', 'flow', 'hifigan'] if k != args.model}
+    with open(args.config, 'r') as f:
+        configs = load_hyperpyyaml(f, overrides=override_dict, overrides_must_match=False)
+    configs['train_conf'].update(vars(args))
+
+    # Init env for ddp
+    init_distributed(args)
+
+    # Get dataset & dataloader
+    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+        init_dataset_and_dataloader(args, configs)
+
+    # Do some sanity checks and save config to arsg.model_dir
+    configs = check_modify_and_save_config(args, configs)
+
+    # Tensorboard summary
+    writer = init_summarywriter(args)
+
+    # load checkpoint
+    model = configs[args.model]
+    if args.checkpoint is not None:
+        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
+
+    # Dispatch model from cpu to gpu
+    model = wrap_cuda_model(args, model)
+
+    # Get optimizer & scheduler
+    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler_gan(args, configs, model)
+
+    # Save init checkpoints
+    info_dict = deepcopy(configs['train_conf'])
+    save_model(model, 'init', info_dict)
+
+    # Get executor
+    executor = Executor()
+
+    # Start training loop
+    for epoch in range(info_dict['max_epoch']):
+        executor.epoch = epoch
+        train_dataset.set_epoch(epoch)
+        dist.barrier()
+        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
+        executor.train_one_epoc(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+        dist.destroy_process_group(group_join)
+
+
+if __name__ == '__main__':
+    main()

+ 54 - 1
cosyvoice/dataset/processor.py

@@ -85,6 +85,7 @@ def filter(data,
     """
     for sample in data:
         sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
+        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
         del sample['audio_data']
         # sample['wav'] is torch.Tensor, we have 100 frames every second
         num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
@@ -134,6 +135,27 @@ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
         yield sample
 
 
+def truncate(data, truncate_length=24576, mode='train'):
+    """ Truncate data.
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+            truncate_length: truncate length
+
+        Returns:
+            Iterable[{key, wav, label, sample_rate}]
+    """
+    for sample in data:
+        waveform = sample['speech']
+        if waveform.shape[1] > truncate_length:
+            start = random.randint(0, waveform.shape[1] - truncate_length)
+            waveform = waveform[:, start: start + truncate_length]
+        else:
+            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
+        sample['speech'] = waveform
+        yield sample
+
+
 def compute_fbank(data,
                   feat_extractor,
                   mode='train'):
@@ -153,7 +175,26 @@ def compute_fbank(data,
         waveform = sample['speech']
         mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
         sample['speech_feat'] = mat
-        del sample['speech']
+        yield sample
+
+def compute_f0(data, pitch_extractor, mode='train'):
+    """ Extract f0
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+
+        Returns:
+            Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'speech' in sample
+        assert 'utt' in sample
+        assert 'text_token' in sample
+        waveform = sample['speech']
+        mat = pitch_extractor(waveform).transpose(1, 2)
+        mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
+        sample['pitch_feat'] = mat[0, 0]
         yield sample
 
 
@@ -325,6 +366,9 @@ def padding(data, use_spk_embedding, mode='train'):
         order = torch.argsort(speech_feat_len, descending=True)
 
         utts = [sample[i]['utt'] for i in order]
+        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
+        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
+        speech = pad_sequence(speech, batch_first=True, padding_value=0)
         speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
         speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
         speech_token = pad_sequence(speech_token,
@@ -335,6 +379,11 @@ def padding(data, use_spk_embedding, mode='train'):
         speech_feat = pad_sequence(speech_feat,
                                    batch_first=True,
                                    padding_value=0)
+        pitch_feat = [sample[i]['pitch_feat'] for i in order]
+        pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
+        pitch_feat = pad_sequence(pitch_feat,
+                                   batch_first=True,
+                                   padding_value=0)
         text = [sample[i]['text'] for i in order]
         text_token = [torch.tensor(sample[i]['text_token']) for i in order]
         text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
@@ -343,10 +392,14 @@ def padding(data, use_spk_embedding, mode='train'):
         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
         batch = {
             "utts": utts,
+            "speech": speech,
+            "speech_len": speech_len,
             "speech_token": speech_token,
             "speech_token_len": speech_token_len,
             "speech_feat": speech_feat,
             "speech_feat_len": speech_feat_len,
+            "pitch_feat": pitch_feat,
+            "pitch_feat_len": pitch_feat_len,
             "text": text,
             "text_token": text_token,
             "text_token_len": text_token_len,

+ 139 - 0
cosyvoice/hifigan/discriminator.py

@@ -0,0 +1,139 @@
+from typing import List
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+from typing import List, Optional, Tuple
+from einops import rearrange
+from torchaudio.transforms import Spectrogram
+
+class MultipleDiscriminator(nn.Module):
+    def __init__(
+            self, mpd: nn.Module, mrd: nn.Module
+    ):
+        super().__init__()
+        self.mpd = mpd
+        self.mrd = mrd
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+class MultiResolutionDiscriminator(nn.Module):
+    def __init__(
+        self,
+        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
+        num_embeddings: Optional[int] = None,
+    ):
+        """
+        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
+        Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+        Args:
+            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
+            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+                Defaults to None.
+        """
+
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
+        )
+
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for d in self.discriminators:
+            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+    def __init__(
+        self,
+        window_length: int,
+        num_embeddings: Optional[int] = None,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
+    ):
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = Spectrogram(
+            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+        if num_embeddings is not None:
+            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+            torch.nn.init.zeros_(self.emb.weight)
+
+        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
+
+    def spectrogram(self, x):
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = rearrange(x, "b f t c -> b c t f")
+        # Split into bands
+        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
+        x_bands = self.spectrogram(x)
+        fmap = []
+        x = []
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+        x = torch.cat(x, dim=-1)
+        if cond_embedding_id is not None:
+            emb = self.emb(cond_embedding_id)
+            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+        else:
+            h = 0
+        x = self.conv_post(x)
+        fmap.append(x)
+        x += h
+
+        return x, fmap

+ 51 - 38
cosyvoice/hifigan/generator.py

@@ -14,7 +14,7 @@
 
 """HIFI-GAN"""
 
-import typing as tp
+from typing import Dict, Optional, List
 import numpy as np
 from scipy.signal import get_window
 import torch
@@ -46,7 +46,7 @@ class ResBlock(torch.nn.Module):
         self,
         channels: int = 512,
         kernel_size: int = 3,
-        dilations: tp.List[int] = [1, 3, 5],
+        dilations: List[int] = [1, 3, 5],
     ):
         super(ResBlock, self).__init__()
         self.convs1 = nn.ModuleList()
@@ -234,13 +234,13 @@ class HiFTGenerator(nn.Module):
             nsf_alpha: float = 0.1,
             nsf_sigma: float = 0.003,
             nsf_voiced_threshold: float = 10,
-            upsample_rates: tp.List[int] = [8, 8],
-            upsample_kernel_sizes: tp.List[int] = [16, 16],
-            istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
-            resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
-            resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-            source_resblock_kernel_sizes: tp.List[int] = [7, 11],
-            source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
+            upsample_rates: List[int] = [8, 8],
+            upsample_kernel_sizes: List[int] = [16, 16],
+            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+            resblock_kernel_sizes: List[int] = [3, 7, 11],
+            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            source_resblock_kernel_sizes: List[int] = [7, 11],
+            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
             lrelu_slope: float = 0.1,
             audio_limit: float = 0.99,
             f0_predictor: torch.nn.Module = None,
@@ -316,11 +316,19 @@ class HiFTGenerator(nn.Module):
         self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
         self.f0_predictor = f0_predictor
 
-    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
-        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
-
-        har_source, _, _ = self.m_source(f0)
-        return har_source.transpose(1, 2)
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+        self.m_source.remove_weight_norm()
+        for l in self.source_downs:
+            remove_weight_norm(l)
+        for l in self.source_resblocks:
+            l.remove_weight_norm()
 
     def _stft(self, x):
         spec = torch.stft(
@@ -338,14 +346,7 @@ class HiFTGenerator(nn.Module):
                                         self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform
 
-    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
-        f0 = self.f0_predictor(x)
-        s = self._f02source(f0)
-
-        # use cache_source to avoid glitch
-        if cache_source.shape[2] != 0:
-            s[:, :, :cache_source.shape[2]] = cache_source
-
+    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
 
@@ -377,22 +378,34 @@ class HiFTGenerator(nn.Module):
 
         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x, s
+        return x
 
-    def remove_weight_norm(self):
-        print('Removing weight norm...')
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
-        self.source_module.remove_weight_norm()
-        for l in self.source_downs:
-            remove_weight_norm(l)
-        for l in self.source_resblocks:
-            l.remove_weight_norm()
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
+        # mel->f0
+        f0 = self.f0_predictor(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        # mel+source->speech
+        generated_speech = self.decode(x=speech_feat, s=s)
+        return generated_speech, f0
 
     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
-        return self.forward(x=mel, cache_source=cache_source)
+    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        # mel->f0
+        f0 = self.f0_predictor(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        # use cache_source to avoid glitch
+        if cache_source.shape[2] != 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+        generated_speech = self.decode(x=speech_feat, s=s)
+        return generated_speech, s

+ 66 - 0
cosyvoice/hifigan/hifigan.py

@@ -0,0 +1,66 @@
+from typing import Dict, Optional
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
+from cosyvoice.utils.losses import tpr_loss, mel_loss
+
+class HiFiGan(nn.Module):
+    def __init__(self, generator, discriminator, mel_spec_transform,
+                 multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
+                 tpr_loss_weight=1.0, tpr_loss_tau=0.04):
+        super(HiFiGan, self).__init__()
+        self.generator = generator
+        self.discriminator = discriminator
+        self.mel_spec_transform = mel_spec_transform
+        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
+        self.feat_match_loss_weight = feat_match_loss_weight
+        self.tpr_loss_weight = tpr_loss_weight
+        self.tpr_loss_tau = tpr_loss_tau
+
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        if batch['turn'] == 'generator':
+            return self.forward_generator(batch, device)
+        else:
+            return self.forward_discriminator(batch, device)
+
+    def forward_generator(self, batch, device):
+        real_speech = batch['speech'].to(device)
+        pitch_feat = batch['pitch_feat'].to(device)
+        # 1. calculate generator outputs
+        generated_speech, generated_f0 = self.generator(batch, device)
+        # 2. calculate discriminator outputs
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
+        loss_gen, _ = generator_loss(y_d_gs)
+        loss_fm = feature_loss(fmap_rs, fmap_gs)
+        loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
+        if self.tpr_loss_weight != 0:
+            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+        else:
+            loss_tpr = torch.zeros(1).to(device)
+        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
+        loss = loss_gen + self.feat_match_loss_weight * loss_fm + self.multi_mel_spectral_recon_loss_weight * loss_mel + self.tpr_loss_weight * loss_tpr + loss_f0
+        return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
+
+    def forward_discriminator(self, batch, device):
+        real_speech = batch['speech'].to(device)
+        pitch_feat = batch['pitch_feat'].to(device)
+        # 1. calculate generator outputs
+        with torch.no_grad():
+            generated_speech, generated_f0 = self.generator(batch, device)
+        # 2. calculate discriminator outputs
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        # 3. calculate discriminator losses, tpr losses [Optional]
+        loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
+        if self.tpr_loss_weight != 0:
+            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+        else:
+            loss_tpr = torch.zeros(1).to(device)
+        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
+        loss = loss_disc + self.tpr_loss_weight * loss_tpr + loss_f0
+        return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}

+ 118 - 0
cosyvoice/utils/executor_gan.py

@@ -0,0 +1,118 @@
+# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+#               2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import nullcontext
+import os
+
+import torch
+import torch.distributed as dist
+
+from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
+
+
+class Executor:
+
+    def __init__(self):
+        self.step = 0
+        self.epoch = 0
+        self.rank = int(os.environ.get('RANK', 0))
+        self.device = torch.device('cuda:{}'.format(self.rank))
+
+    def train_one_epoc(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, writer, info_dict, group_join):
+        ''' Train one epoch
+        '''
+
+        lr = optimizer.param_groups[0]['lr']
+        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
+        logging.info('using accumulate grad, new batch size is {} times'
+                     ' larger than before'.format(info_dict['accum_grad']))
+        # A context manager to be used in conjunction with an instance of
+        # torch.nn.parallel.DistributedDataParallel to be able to train
+        # with uneven inputs across participating processes.
+        model.train()
+        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+        with model_context():
+            for batch_idx, batch_dict in enumerate(train_data_loader):
+                info_dict["tag"] = "TRAIN"
+                info_dict["step"] = self.step
+                info_dict["epoch"] = self.epoch
+                info_dict["batch_idx"] = batch_idx
+                if cosyvoice_join(group_join, info_dict):
+                    break
+
+                # Disable gradient synchronizations across DDP processes.
+                # Within this context, gradients will be accumulated on module
+                # variables, which will later be synchronized.
+                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+                    context = model.no_sync
+                # Used for single gpu training and DDP gradient synchronization
+                # processes.
+                else:
+                    context = nullcontext
+
+                with context():
+                    batch_dict['turn'] = 'discriminator'
+                    info_dict = batch_forward(model, batch_dict, info_dict)
+                    info_dict = batch_backward(model, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, info_dict)
+                log_per_step(writer, info_dict)
+                with context():
+                    batch_dict['turn'] = 'generator'
+                    info_dict = batch_forward(model, batch_dict, info_dict)
+                    info_dict = batch_backward(model, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
+                log_per_step(writer, info_dict)
+                # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
+                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                   (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                    dist.barrier()
+                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+                    model.train()
+                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                    self.step += 1
+        dist.barrier()
+        self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
+
+    @torch.inference_mode()
+    def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
+        ''' Cross validation on
+        '''
+        logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
+        model.eval()
+        total_num_utts, total_loss_dict = 0, {}  # avoid division by 0
+        for batch_idx, batch_dict in enumerate(cv_data_loader):
+            info_dict["tag"] = "CV"
+            info_dict["step"] = self.step
+            info_dict["epoch"] = self.epoch
+            info_dict["batch_idx"] = batch_idx
+
+            num_utts = len(batch_dict["utts"])
+            total_num_utts += num_utts
+
+            batch_dict['turn'] = 'generator'
+            info_dict = batch_forward(model, batch_dict, info_dict)
+
+            for k, v in info_dict['loss_dict'].items():
+                if k not in total_loss_dict:
+                    total_loss_dict[k] = []
+                total_loss_dict[k].append(v.item() * num_utts)
+            log_per_step(None, info_dict)
+        for k, v in total_loss_dict.items():
+            total_loss_dict[k] = sum(v) / total_num_utts
+        info_dict['loss_dict'] = total_loss_dict
+        log_per_save(writer, info_dict)
+        model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
+        save_model(model, model_name, info_dict)

+ 18 - 0
cosyvoice/utils/losses.py

@@ -0,0 +1,18 @@
+import torch
+import torch.nn.functional as F
+
+def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
+    loss = 0
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        m_DG = torch.median((dr - dg))
+        L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
+        loss += tau - F.relu(tau - L_rel)
+    return loss
+
+def mel_loss(real_speech, generated_speech, mel_transforms):
+    loss = 0
+    for transform in mel_transforms:
+        mel_r = transform(real_speech)
+        mel_g = transform(generated_speech)
+        loss += F.l1_loss(mel_g, mel_r)
+    return loss

+ 43 - 0
cosyvoice/utils/train_utils.py

@@ -142,6 +142,49 @@ def init_optimizer_and_scheduler(args, configs, model):
     return model, optimizer, scheduler
 
 
+def init_optimizer_and_scheduler_gan(args, configs, model):
+    if configs['train_conf']['optim'] == 'adam':
+        optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+    elif configs['train_conf']['optim'] == 'adamw':
+        optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+    else:
+        raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+    if configs['train_conf']['scheduler'] == 'warmuplr':
+        scheduler_type = WarmupLR
+        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+        scheduler_type = NoamHoldAnnealing
+        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler_type = ConstantLR
+        scheduler = ConstantLR(optimizer)
+    else:
+        raise ValueError("unknown scheduler: " + configs['train_conf'])
+
+    if configs['train_conf']['optim_d'] == 'adam':
+        optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+    elif configs['train_conf']['optim_d'] == 'adamw':
+        optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+    else:
+        raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+    if configs['train_conf']['scheduler_d'] == 'warmuplr':
+        scheduler_type = WarmupLR
+        scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
+        scheduler_type = NoamHoldAnnealing
+        scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
+    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler_type = ConstantLR
+        scheduler_d = ConstantLR(optimizer_d)
+    else:
+        raise ValueError("unknown scheduler: " + configs['train_conf'])
+
+    # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
+    return model, optimizer, scheduler, optimizer_d, scheduler_d
+
+
 def init_summarywriter(args):
     writer = None
     if int(os.environ.get('RANK', 0)) == 0:

+ 141 - 0
examples/libritts/cosyvoice/conf/cosyvoice.hifigan.yaml

@@ -0,0 +1,141 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 22050
+text_encoder_input_size: 512
+llm_input_size: 1024
+llm_output_size: 1024
+spk_embed_dim: 192
+
+# model params
+# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
+# for system/third_party class/function, we do not require this.
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+    in_channels: 80
+    base_channels: 512
+    nb_harmonics: 8
+    sampling_rate: !ref <sample_rate>
+    nsf_alpha: 0.1
+    nsf_sigma: 0.003
+    nsf_voiced_threshold: 10
+    upsample_rates: [8, 8]
+    upsample_kernel_sizes: [16, 16]
+    istft_params:
+        n_fft: 16
+        hop_len: 4
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    source_resblock_kernel_sizes: [7, 11]
+    source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
+    lrelu_slope: 0.1
+    audio_limit: 0.99
+    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+        num_class: 1
+        in_channels: 80
+        cond_channels: 512
+
+mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: 8000
+    center: False
+hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
+    generator: !ref <hift>
+    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
+        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+    mel_spec_transform: [
+        !ref <mel_spec_transform1>
+    ]
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
+    multilingual: True
+    num_languages: 100
+    language: 'en'
+    task: 'transcribe'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+    get_tokenizer: !ref <get_tokenizer>
+    allowed_special: 'all'
+filter: !name:cosyvoice.dataset.processor.filter
+    max_length: 40960
+    min_length: 0
+    token_max_length: 200
+    token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+    resample_rate: !ref <sample_rate>
+truncate: !name:cosyvoice.dataset.processor.truncate
+    truncate_length: 24576 # must be a multiplier of hop_size
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: 8000
+    center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+    feat_extractor: !ref <feat_extractor>
+pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch
+    sample_rate: !ref <sample_rate>
+    frame_length: 46.4 # match feat_extractor win_size/sampling_rate
+    frame_shift: 11.6 # match feat_extractor hop_size/sampling_rate
+compute_f0: !name:cosyvoice.dataset.processor.compute_f0
+    pitch_extractor: !ref <pitch_extractor>
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+    normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+    shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+    sort_size: 500  # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+    batch_type: 'dynamic'
+    max_frames_in_batch: 1200
+padding: !name:cosyvoice.dataset.processor.padding
+    use_spk_embedding: False # change to True during sft
+
+# dataset processor pipeline
+data_pipeline: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <truncate>,
+    !ref <compute_fbank>,
+    !ref <compute_f0>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+
+# train conf
+train_conf:
+    optim: adam
+    optim_conf:
+        lr: 0.002 # change to 0.001 if you want to train flow from scratch
+    scheduler: warmuplr
+    scheduler_conf:
+        warmup_steps: 25000
+    optim_d: adam
+    optim_conf_d:
+        lr: 0.002 # change to 0.001 if you want to train flow from scratch
+    scheduler_d: warmuplr
+    scheduler_conf_d:
+        warmup_steps: 25000
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 2
+    log_interval: 100
+    save_per_step: -1