Parcourir la source

Merge pull request #494 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
Xiang Lyu il y a 1 an
Parent
commit
de76577a9f

+ 1 - 1
.github/workflows/lint.yml

@@ -52,5 +52,5 @@ jobs:
           set -eux
           pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
           flake8 --version
-          flake8 --max-line-length 150 --ignore B006,B008,B905,C408,E402,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
+          flake8 --max-line-length 150 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
           if [ $? != 0 ]; then exit 1; fi

+ 1 - 3
README.md

@@ -26,9 +26,7 @@ For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVo
 
     - [ ] 25hz llama based llm model which supports lora finetune
     - [ ] Support more instruction mode
-    - [ ] Voice conversion
     - [ ] Music generation
-    - [ ] Training script sample based on Mandarin
     - [ ] CosyVoice-500M trained with more multi-lingual data
     - [ ] More...
 
@@ -113,7 +111,7 @@ from cosyvoice.cli.cosyvoice import CosyVoice
 from cosyvoice.utils.file_utils import load_wav
 import torchaudio
 
-cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
+cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
 # sft usage
 print(cosyvoice.list_avaliable_spks())
 # change stream=True for chunk stream inference

+ 92 - 0
cosyvoice/bin/average_model.py

@@ -0,0 +1,92 @@
+# Copyright (c) 2020 Mobvoi Inc (Di Wu)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import glob
+
+import yaml
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='averaged model')
+    parser.add_argument('--src_path',
+                        required=True,
+                        help='src model path for average')
+    parser.add_argument('--val_best',
+                        action="store_true",
+                        help='averaged model')
+    parser.add_argument('--num',
+                        default=5,
+                        type=int,
+                        help='nums for averaged model')
+
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+def main():
+    args = get_args()
+    val_scores = []
+    if args.val_best:
+        yamls = glob.glob('{}/*.yaml'.format(args.src_path))
+        yamls = [
+            f for f in yamls
+            if not (os.path.basename(f).startswith('train')
+                    or os.path.basename(f).startswith('init'))
+        ]
+        for y in yamls:
+            with open(y, 'r') as f:
+                dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
+                loss = float(dic_yaml['loss_dict']['loss'])
+                epoch = int(dic_yaml['epoch'])
+                step = int(dic_yaml['step'])
+                tag = dic_yaml['tag']
+                val_scores += [[epoch, step, loss, tag]]
+        sorted_val_scores = sorted(val_scores,
+                                   key=lambda x: x[2],
+                                   reverse=False)
+        print("best val (epoch, step, loss, tag) = " +
+              str(sorted_val_scores[:args.num]))
+        path_list = [
+            args.src_path + '/epoch_{}_whole.pt'.format(score[0])
+            for score in sorted_val_scores[:args.num]
+        ]
+    print(path_list)
+    avg = {}
+    num = args.num
+    assert num == len(path_list)
+    for path in path_list:
+        print('Processing {}'.format(path))
+        states = torch.load(path, map_location=torch.device('cpu'))
+        for k in states.keys():
+            if k not in avg.keys():
+                avg[k] = states[k].clone()
+            else:
+                avg[k] += states[k]
+    # average
+    for k in avg.keys():
+        if avg[k] is not None:
+            # pytorch 1.6 use true_divide instead of /=
+            avg[k] = torch.true_divide(avg[k], num)
+    print('Saving to {}'.format(args.dst_model))
+    torch.save(avg, args.dst_model)
+
+
+if __name__ == '__main__':
+    main()

+ 19 - 7
cosyvoice/bin/train.py

@@ -18,6 +18,7 @@ import datetime
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 from copy import deepcopy
+import os
 import torch
 import torch.distributed as dist
 import deepspeed
@@ -73,7 +74,7 @@ def get_args():
                         choices=['model_only', 'model+optimizer'],
                         help='save model/optimizer states')
     parser.add_argument('--timeout',
-                        default=30,
+                        default=60,
                         type=int,
                         help='timeout (in seconds) of cosyvoice_join.')
     parser = deepspeed.add_config_arguments(parser)
@@ -86,8 +87,12 @@ def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
+    # gan train has some special initialization logic
+    gan = True if args.model == 'hifigan' else False
 
-    override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
+    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+    if gan is True:
+        override_dict.pop('hift')
     with open(args.config, 'r') as f:
         configs = load_hyperpyyaml(f, overrides=override_dict)
     configs['train_conf'].update(vars(args))
@@ -97,7 +102,7 @@ def main():
 
     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs)
+        init_dataset_and_dataloader(args, configs, gan)
 
     # Do some sanity checks and save config to arsg.model_dir
     configs = check_modify_and_save_config(args, configs)
@@ -108,20 +113,23 @@ def main():
     # load checkpoint
     model = configs[args.model]
     if args.checkpoint is not None:
-        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
+        if os.path.exists(args.checkpoint):
+            model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
+        else:
+            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
 
     # Dispatch model from cpu to gpu
     model = wrap_cuda_model(args, model)
 
     # Get optimizer & scheduler
-    model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
+    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
 
     # Save init checkpoints
     info_dict = deepcopy(configs['train_conf'])
     save_model(model, 'init', info_dict)
 
     # Get executor
-    executor = Executor()
+    executor = Executor(gan=gan)
 
     # Start training loop
     for epoch in range(info_dict['max_epoch']):
@@ -129,7 +137,11 @@ def main():
         train_dataset.set_epoch(epoch)
         dist.barrier()
         group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
-        executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
+        if gan is True:
+            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                                        writer, info_dict, group_join)
+        else:
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
         dist.destroy_process_group(group_join)
 
 

+ 2 - 2
cosyvoice/cli/cosyvoice.py

@@ -23,7 +23,7 @@ from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True, load_onnx=False):
+    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -37,7 +37,7 @@ class CosyVoice:
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
                                           configs['allowed_special'])
-        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))

+ 13 - 5
cosyvoice/cli/model.py

@@ -26,11 +26,13 @@ class CosyVoiceModel:
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
-                 hift: torch.nn.Module):
+                 hift: torch.nn.Module,
+                 fp16: bool):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.fp16 = fp16
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
@@ -56,13 +58,17 @@ class CosyVoiceModel:
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
         self.llm.to(self.device).eval()
-        self.llm.half()
+        if self.fp16 is True:
+            self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         self.flow.to(self.device).eval()
-        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device)}
+        self.hift.load_state_dict(hift_state_dict, strict=False)
         self.hift.to(self.device).eval()
 
     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
+        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
         self.llm.text_encoder = llm_text_encoder
         llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
@@ -80,6 +86,8 @@ class CosyVoiceModel:
         self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        if self.fp16 is True:
+            llm_embedding = llm_embedding.half()
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
                                         text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
@@ -87,7 +95,7 @@ class CosyVoiceModel:
                                         prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                         prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device).half()):
+                                        embedding=llm_embedding.to(self.device)):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
@@ -123,7 +131,7 @@ class CosyVoiceModel:
             if speed != 1.0:
                 assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                 tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
-            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
             if self.hift_cache_dict[uuid] is not None:
                 tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech

+ 5 - 1
cosyvoice/dataset/dataset.py

@@ -126,6 +126,7 @@ class DataList(IterableDataset):
 def Dataset(data_list_file,
             data_pipeline,
             mode='train',
+            gan=False,
             shuffle=True,
             partition=True,
             tts_file='',
@@ -153,8 +154,11 @@ def Dataset(data_list_file,
                        shuffle=shuffle,
                        partition=partition)
     if mode == 'inference':
-        # map partial arg tts_data in inference mode
+        # map partial arg to parquet_opener func in inference mode
         data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
+    if gan is True:
+        # map partial arg to padding func in gan mode
+        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
     for func in data_pipeline:
         dataset = Processor(dataset, func, mode=mode)
     return dataset

+ 62 - 2
cosyvoice/dataset/processor.py

@@ -85,6 +85,7 @@ def filter(data,
     """
     for sample in data:
         sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
+        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
         del sample['audio_data']
         # sample['wav'] is torch.Tensor, we have 100 frames every second
         num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
@@ -134,6 +135,27 @@ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
         yield sample
 
 
+def truncate(data, truncate_length=24576, mode='train'):
+    """ Truncate data.
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+            truncate_length: truncate length
+
+        Returns:
+            Iterable[{key, wav, label, sample_rate}]
+    """
+    for sample in data:
+        waveform = sample['speech']
+        if waveform.shape[1] > truncate_length:
+            start = random.randint(0, waveform.shape[1] - truncate_length)
+            waveform = waveform[:, start: start + truncate_length]
+        else:
+            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
+        sample['speech'] = waveform
+        yield sample
+
+
 def compute_fbank(data,
                   feat_extractor,
                   mode='train'):
@@ -153,7 +175,27 @@ def compute_fbank(data,
         waveform = sample['speech']
         mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
         sample['speech_feat'] = mat
-        del sample['speech']
+        yield sample
+
+
+def compute_f0(data, pitch_extractor, mode='train'):
+    """ Extract f0
+
+        Args:
+            data: Iterable[{key, wav, label, sample_rate}]
+
+        Returns:
+            Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'speech' in sample
+        assert 'utt' in sample
+        assert 'text_token' in sample
+        waveform = sample['speech']
+        mat = pitch_extractor(waveform).transpose(1, 2)
+        mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
+        sample['pitch_feat'] = mat[0, 0]
         yield sample
 
 
@@ -309,7 +351,7 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, m
             logging.fatal('Unsupported batch type {}'.format(batch_type))
 
 
-def padding(data, use_spk_embedding, mode='train'):
+def padding(data, use_spk_embedding, mode='train', gan=False):
     """ Padding the data into training data
 
         Args:
@@ -325,6 +367,9 @@ def padding(data, use_spk_embedding, mode='train'):
         order = torch.argsort(speech_feat_len, descending=True)
 
         utts = [sample[i]['utt'] for i in order]
+        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
+        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
+        speech = pad_sequence(speech, batch_first=True, padding_value=0)
         speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
         speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
         speech_token = pad_sequence(speech_token,
@@ -343,6 +388,8 @@ def padding(data, use_spk_embedding, mode='train'):
         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
         batch = {
             "utts": utts,
+            "speech": speech,
+            "speech_len": speech_len,
             "speech_token": speech_token,
             "speech_token_len": speech_token_len,
             "speech_feat": speech_feat,
@@ -353,6 +400,19 @@ def padding(data, use_spk_embedding, mode='train'):
             "utt_embedding": utt_embedding,
             "spk_embedding": spk_embedding,
         }
+        if gan is True:
+            # in gan train, we need pitch_feat
+            pitch_feat = [sample[i]['pitch_feat'] for i in order]
+            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
+            pitch_feat = pad_sequence(pitch_feat,
+                                      batch_first=True,
+                                      padding_value=0)
+            batch["pitch_feat"] = pitch_feat
+            batch["pitch_feat_len"] = pitch_feat_len
+        else:
+            # only gan train needs speech, delete it to save memory
+            del batch["speech"]
+            del batch["speech_len"]
         if mode == 'inference':
             tts_text = [sample[i]['tts_text'] for i in order]
             tts_index = [sample[i]['tts_index'] for i in order]

+ 140 - 0
cosyvoice/hifigan/discriminator.py

@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+from typing import List, Optional, Tuple
+from einops import rearrange
+from torchaudio.transforms import Spectrogram
+
+
+class MultipleDiscriminator(nn.Module):
+    def __init__(
+            self, mpd: nn.Module, mrd: nn.Module
+    ):
+        super().__init__()
+        self.mpd = mpd
+        self.mrd = mrd
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
+        y_d_rs += this_y_d_rs
+        y_d_gs += this_y_d_gs
+        fmap_rs += this_fmap_rs
+        fmap_gs += this_fmap_gs
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiResolutionDiscriminator(nn.Module):
+    def __init__(
+        self,
+        fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
+        num_embeddings: Optional[int] = None,
+    ):
+        """
+        Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
+        Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+        Args:
+            fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
+            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+                Defaults to None.
+        """
+
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
+        )
+
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for d in self.discriminators:
+            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+    def __init__(
+        self,
+        window_length: int,
+        num_embeddings: Optional[int] = None,
+        channels: int = 32,
+        hop_factor: float = 0.25,
+        bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
+    ):
+        super().__init__()
+        self.window_length = window_length
+        self.hop_factor = hop_factor
+        self.spec_fn = Spectrogram(
+            n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
+        )
+        n_fft = window_length // 2 + 1
+        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+        self.bands = bands
+        convs = lambda: nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
+            ]
+        )
+        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+        if num_embeddings is not None:
+            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+            torch.nn.init.zeros_(self.emb.weight)
+
+        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
+
+    def spectrogram(self, x):
+        # Remove DC offset
+        x = x - x.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        x = self.spec_fn(x)
+        x = torch.view_as_real(x)
+        x = rearrange(x, "b f t c -> b c t f")
+        # Split into bands
+        x_bands = [x[..., b[0]: b[1]] for b in self.bands]
+        return x_bands
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
+        x_bands = self.spectrogram(x)
+        fmap = []
+        x = []
+        for band, stack in zip(x_bands, self.band_convs):
+            for i, layer in enumerate(stack):
+                band = layer(band)
+                band = torch.nn.functional.leaky_relu(band, 0.1)
+                if i > 0:
+                    fmap.append(band)
+            x.append(band)
+        x = torch.cat(x, dim=-1)
+        if cond_embedding_id is not None:
+            emb = self.emb(cond_embedding_id)
+            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+        else:
+            h = 0
+        x = self.conv_post(x)
+        fmap.append(x)
+        x += h
+
+        return x, fmap

+ 51 - 38
cosyvoice/hifigan/generator.py

@@ -14,7 +14,7 @@
 
 """HIFI-GAN"""
 
-import typing as tp
+from typing import Dict, Optional, List
 import numpy as np
 from scipy.signal import get_window
 import torch
@@ -46,7 +46,7 @@ class ResBlock(torch.nn.Module):
         self,
         channels: int = 512,
         kernel_size: int = 3,
-        dilations: tp.List[int] = [1, 3, 5],
+        dilations: List[int] = [1, 3, 5],
     ):
         super(ResBlock, self).__init__()
         self.convs1 = nn.ModuleList()
@@ -234,13 +234,13 @@ class HiFTGenerator(nn.Module):
             nsf_alpha: float = 0.1,
             nsf_sigma: float = 0.003,
             nsf_voiced_threshold: float = 10,
-            upsample_rates: tp.List[int] = [8, 8],
-            upsample_kernel_sizes: tp.List[int] = [16, 16],
-            istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
-            resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
-            resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-            source_resblock_kernel_sizes: tp.List[int] = [7, 11],
-            source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
+            upsample_rates: List[int] = [8, 8],
+            upsample_kernel_sizes: List[int] = [16, 16],
+            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+            resblock_kernel_sizes: List[int] = [3, 7, 11],
+            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            source_resblock_kernel_sizes: List[int] = [7, 11],
+            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
             lrelu_slope: float = 0.1,
             audio_limit: float = 0.99,
             f0_predictor: torch.nn.Module = None,
@@ -316,11 +316,19 @@ class HiFTGenerator(nn.Module):
         self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
         self.f0_predictor = f0_predictor
 
-    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
-        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
-
-        har_source, _, _ = self.m_source(f0)
-        return har_source.transpose(1, 2)
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+        self.m_source.remove_weight_norm()
+        for l in self.source_downs:
+            remove_weight_norm(l)
+        for l in self.source_resblocks:
+            l.remove_weight_norm()
 
     def _stft(self, x):
         spec = torch.stft(
@@ -338,14 +346,7 @@ class HiFTGenerator(nn.Module):
                                         self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform
 
-    def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
-        f0 = self.f0_predictor(x)
-        s = self._f02source(f0)
-
-        # use cache_source to avoid glitch
-        if cache_source.shape[2] != 0:
-            s[:, :, :cache_source.shape[2]] = cache_source
-
+    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
         s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
 
@@ -377,22 +378,34 @@ class HiFTGenerator(nn.Module):
 
         x = self._istft(magnitude, phase)
         x = torch.clamp(x, -self.audio_limit, self.audio_limit)
-        return x, s
+        return x
 
-    def remove_weight_norm(self):
-        print('Removing weight norm...')
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
-        self.source_module.remove_weight_norm()
-        for l in self.source_downs:
-            remove_weight_norm(l)
-        for l in self.source_resblocks:
-            l.remove_weight_norm()
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
+        # mel->f0
+        f0 = self.f0_predictor(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        # mel+source->speech
+        generated_speech = self.decode(x=speech_feat, s=s)
+        return generated_speech, f0
 
     @torch.inference_mode()
-    def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
-        return self.forward(x=mel, cache_source=cache_source)
+    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
+        # mel->f0
+        f0 = self.f0_predictor(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        # use cache_source to avoid glitch
+        if cache_source.shape[2] != 0:
+            s[:, :, :cache_source.shape[2]] = cache_source
+        generated_speech = self.decode(x=speech_feat, s=s)
+        return generated_speech, s

+ 69 - 0
cosyvoice/hifigan/hifigan.py

@@ -0,0 +1,69 @@
+from typing import Dict, Optional
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
+from cosyvoice.utils.losses import tpr_loss, mel_loss
+
+
+class HiFiGan(nn.Module):
+    def __init__(self, generator, discriminator, mel_spec_transform,
+                 multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
+                 tpr_loss_weight=1.0, tpr_loss_tau=0.04):
+        super(HiFiGan, self).__init__()
+        self.generator = generator
+        self.discriminator = discriminator
+        self.mel_spec_transform = mel_spec_transform
+        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
+        self.feat_match_loss_weight = feat_match_loss_weight
+        self.tpr_loss_weight = tpr_loss_weight
+        self.tpr_loss_tau = tpr_loss_tau
+
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        if batch['turn'] == 'generator':
+            return self.forward_generator(batch, device)
+        else:
+            return self.forward_discriminator(batch, device)
+
+    def forward_generator(self, batch, device):
+        real_speech = batch['speech'].to(device)
+        pitch_feat = batch['pitch_feat'].to(device)
+        # 1. calculate generator outputs
+        generated_speech, generated_f0 = self.generator(batch, device)
+        # 2. calculate discriminator outputs
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
+        loss_gen, _ = generator_loss(y_d_gs)
+        loss_fm = feature_loss(fmap_rs, fmap_gs)
+        loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
+        if self.tpr_loss_weight != 0:
+            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+        else:
+            loss_tpr = torch.zeros(1).to(device)
+        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
+        loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
+            self.multi_mel_spectral_recon_loss_weight * loss_mel + \
+            self.tpr_loss_weight * loss_tpr + loss_f0
+        return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
+
+    def forward_discriminator(self, batch, device):
+        real_speech = batch['speech'].to(device)
+        pitch_feat = batch['pitch_feat'].to(device)
+        # 1. calculate generator outputs
+        with torch.no_grad():
+            generated_speech, generated_f0 = self.generator(batch, device)
+        # 2. calculate discriminator outputs
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        # 3. calculate discriminator losses, tpr losses [Optional]
+        loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
+        if self.tpr_loss_weight != 0:
+            loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
+        else:
+            loss_tpr = torch.zeros(1).to(device)
+        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
+        loss = loss_disc + self.tpr_loss_weight * loss_tpr + loss_f0
+        return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}

+ 62 - 1
cosyvoice/utils/executor.py

@@ -25,7 +25,8 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
 
 class Executor:
 
-    def __init__(self):
+    def __init__(self, gan: bool = False):
+        self.gan = gan
         self.step = 0
         self.epoch = 0
         self.rank = int(os.environ.get('RANK', 0))
@@ -80,6 +81,64 @@ class Executor:
         dist.barrier()
         self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
 
+    def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                           writer, info_dict, group_join):
+        ''' Train one epoch
+        '''
+
+        lr = optimizer.param_groups[0]['lr']
+        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
+        logging.info('using accumulate grad, new batch size is {} times'
+                     ' larger than before'.format(info_dict['accum_grad']))
+        # A context manager to be used in conjunction with an instance of
+        # torch.nn.parallel.DistributedDataParallel to be able to train
+        # with uneven inputs across participating processes.
+        model.train()
+        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+        with model_context():
+            for batch_idx, batch_dict in enumerate(train_data_loader):
+                info_dict["tag"] = "TRAIN"
+                info_dict["step"] = self.step
+                info_dict["epoch"] = self.epoch
+                info_dict["batch_idx"] = batch_idx
+                if cosyvoice_join(group_join, info_dict):
+                    break
+
+                # Disable gradient synchronizations across DDP processes.
+                # Within this context, gradients will be accumulated on module
+                # variables, which will later be synchronized.
+                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+                    context = model.no_sync
+                # Used for single gpu training and DDP gradient synchronization
+                # processes.
+                else:
+                    context = nullcontext
+
+                with context():
+                    batch_dict['turn'] = 'discriminator'
+                    info_dict = batch_forward(model, batch_dict, info_dict)
+                    info_dict = batch_backward(model, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, info_dict)
+                optimizer.zero_grad()
+                log_per_step(writer, info_dict)
+                with context():
+                    batch_dict['turn'] = 'generator'
+                    info_dict = batch_forward(model, batch_dict, info_dict)
+                    info_dict = batch_backward(model, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
+                optimizer_d.zero_grad()
+                log_per_step(writer, info_dict)
+                # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
+                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                   (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                    dist.barrier()
+                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+                    model.train()
+                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                    self.step += 1
+        dist.barrier()
+        self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
+
     @torch.inference_mode()
     def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
         ''' Cross validation on
@@ -96,6 +155,8 @@ class Executor:
             num_utts = len(batch_dict["utts"])
             total_num_utts += num_utts
 
+            if self.gan is True:
+                batch_dict['turn'] = 'generator'
             info_dict = batch_forward(model, batch_dict, info_dict)
 
             for k, v in info_dict['loss_dict'].items():

+ 20 - 0
cosyvoice/utils/losses.py

@@ -0,0 +1,20 @@
+import torch
+import torch.nn.functional as F
+
+
+def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
+    loss = 0
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        m_DG = torch.median((dr - dg))
+        L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
+        loss += tau - F.relu(tau - L_rel)
+    return loss
+
+
+def mel_loss(real_speech, generated_speech, mel_transforms):
+    loss = 0
+    for transform in mel_transforms:
+        mel_r = transform(real_speech)
+        mel_g = transform(generated_speech)
+        loss += F.l1_loss(mel_g, mel_r)
+    return loss

+ 41 - 17
cosyvoice/utils/train_utils.py

@@ -51,9 +51,10 @@ def init_distributed(args):
     return world_size, local_rank, rank
 
 
-def init_dataset_and_dataloader(args, configs):
-    train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True)
-    cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False)
+def init_dataset_and_dataloader(args, configs, gan):
+    data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
+    train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True)
+    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False)
 
     # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
     train_data_loader = DataLoader(train_dataset,
@@ -108,30 +109,31 @@ def wrap_cuda_model(args, model):
     return model
 
 
-def init_optimizer_and_scheduler(args, configs, model):
-    if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
-    elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+def init_optimizer_and_scheduler(args, configs, model, gan):
+    key = 'train_conf_gan' if gan is True else 'train_conf'
+    if configs[key]['optim'] == 'adam':
+        optimizer = optim.Adam(model.parameters(), **configs[key]['optim_conf'])
+    elif configs[key]['optim'] == 'adamw':
+        optimizer = optim.AdamW(model.parameters(), **configs[key]['optim_conf'])
     else:
-        raise ValueError("unknown optimizer: " + configs['train_conf'])
+        raise ValueError("unknown optimizer: " + configs[key])
 
-    if configs['train_conf']['scheduler'] == 'warmuplr':
+    if configs[key]['scheduler'] == 'warmuplr':
         scheduler_type = WarmupLR
-        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+        scheduler = WarmupLR(optimizer, **configs[key]['scheduler_conf'])
+    elif configs[key]['scheduler'] == 'NoamHoldAnnealing':
         scheduler_type = NoamHoldAnnealing
-        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'constantlr':
+        scheduler = NoamHoldAnnealing(optimizer, **configs[key]['scheduler_conf'])
+    elif configs[key]['scheduler'] == 'constantlr':
         scheduler_type = ConstantLR
         scheduler = ConstantLR(optimizer)
     else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
+        raise ValueError("unknown scheduler: " + configs[key])
 
     # use deepspeed optimizer for speedup
     if args.train_engine == "deepspeed":
         def scheduler(opt):
-            return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
+            return scheduler_type(opt, **configs[key]['scheduler_conf'])
         model, optimizer, _, scheduler = deepspeed.initialize(
             args=args,
             model=model,
@@ -139,7 +141,29 @@ def init_optimizer_and_scheduler(args, configs, model):
             lr_scheduler=scheduler,
             model_parameters=model.parameters())
 
-    return model, optimizer, scheduler
+    # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
+    if gan is True:
+        if configs[key]['optim_d'] == 'adam':
+            optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs[key]['optim_conf'])
+        elif configs[key]['optim_d'] == 'adamw':
+            optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs[key]['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs[key])
+
+        if configs[key]['scheduler_d'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler_d = WarmupLR(optimizer_d, **configs[key]['scheduler_conf'])
+        elif configs[key]['scheduler_d'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler_d = NoamHoldAnnealing(optimizer_d, **configs[key]['scheduler_conf'])
+        elif configs[key]['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler_d = ConstantLR(optimizer_d)
+        else:
+            raise ValueError("unknown scheduler: " + configs[key])
+    else:
+        optimizer_d, scheduler_d = None, None
+    return model, optimizer, scheduler, optimizer_d, scheduler_d
 
 
 def init_summarywriter(args):

+ 58 - 1
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

@@ -133,6 +133,25 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
         in_channels: 80
         cond_channels: 512
 
+# gan related module
+mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: 8000
+    center: False
+hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
+    generator: !ref <hift>
+    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
+        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+    mel_spec_transform: [
+        !ref <mel_spec_transform1>
+    ]
+
 # processor functions
 parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
 get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
@@ -151,6 +170,8 @@ filter: !name:cosyvoice.dataset.processor.filter
     token_min_length: 1
 resample: !name:cosyvoice.dataset.processor.resample
     resample_rate: !ref <sample_rate>
+truncate: !name:cosyvoice.dataset.processor.truncate
+    truncate_length: 24576 # must be a multiplier of hop_size
 feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     n_fft: 1024
     num_mels: 80
@@ -162,6 +183,12 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
+pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch
+    sample_rate: !ref <sample_rate>
+    frame_length: 46.4 # match feat_extractor win_size/sampling_rate
+    frame_shift: 11.6 # match feat_extractor hop_size/sampling_rate
+compute_f0: !name:cosyvoice.dataset.processor.compute_f0
+    pitch_extractor: !ref <pitch_extractor>
 parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
     normalize: True
 shuffle: !name:cosyvoice.dataset.processor.shuffle
@@ -187,8 +214,22 @@ data_pipeline: [
     !ref <batch>,
     !ref <padding>,
 ]
+data_pipeline_gan: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <truncate>,
+    !ref <compute_fbank>,
+    !ref <compute_f0>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
 
-# train conf
+# llm flow train conf
 train_conf:
     optim: adam
     optim_conf:
@@ -200,4 +241,20 @@ train_conf:
     grad_clip: 5
     accum_grad: 2
     log_interval: 100
+    save_per_step: -1
+
+# gan train conf
+train_conf_gan:
+    optim: adam
+    optim_conf:
+        lr: 0.0002 # use small lr for gan training
+    scheduler: constantlr
+    optim_d: adam
+    optim_conf_d:
+        lr: 0.0002 # use small lr for gan training
+    scheduler_d: constantlr
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 1 # in gan training, accum_grad must be 1
+    log_interval: 100
     save_per_step: -1

+ 59 - 2
examples/libritts/cosyvoice/conf/cosyvoice.yaml

@@ -133,6 +133,25 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
         in_channels: 80
         cond_channels: 512
 
+# gan related module
+mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: 8000
+    center: False
+hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
+    generator: !ref <hift>
+    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
+        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+    mel_spec_transform: [
+        !ref <mel_spec_transform1>
+    ]
+
 # processor functions
 parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
 get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
@@ -151,6 +170,8 @@ filter: !name:cosyvoice.dataset.processor.filter
     token_min_length: 1
 resample: !name:cosyvoice.dataset.processor.resample
     resample_rate: !ref <sample_rate>
+truncate: !name:cosyvoice.dataset.processor.truncate
+    truncate_length: 24576 # must be a multiplier of hop_size
 feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     n_fft: 1024
     num_mels: 80
@@ -162,6 +183,12 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
+pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch
+    sample_rate: !ref <sample_rate>
+    frame_length: 46.4 # match feat_extractor win_size/sampling_rate
+    frame_shift: 11.6 # match feat_extractor hop_size/sampling_rate
+compute_f0: !name:cosyvoice.dataset.processor.compute_f0
+    pitch_extractor: !ref <pitch_extractor>
 parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
     normalize: True
 shuffle: !name:cosyvoice.dataset.processor.shuffle
@@ -170,7 +197,7 @@ sort: !name:cosyvoice.dataset.processor.sort
     sort_size: 500  # sort_size should be less than shuffle_size
 batch: !name:cosyvoice.dataset.processor.batch
     batch_type: 'dynamic'
-    max_frames_in_batch: 2000
+    max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g
 padding: !name:cosyvoice.dataset.processor.padding
     use_spk_embedding: False # change to True during sft
 
@@ -187,8 +214,22 @@ data_pipeline: [
     !ref <batch>,
     !ref <padding>,
 ]
+data_pipeline_gan: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <truncate>,
+    !ref <compute_fbank>,
+    !ref <compute_f0>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
 
-# train conf
+# llm flow train conf
 train_conf:
     optim: adam
     optim_conf:
@@ -200,4 +241,20 @@ train_conf:
     grad_clip: 5
     accum_grad: 2
     log_interval: 100
+    save_per_step: -1
+
+# gan train conf
+train_conf_gan:
+    optim: adam
+    optim_conf:
+        lr: 0.0002 # use small lr for gan training
+    scheduler: constantlr
+    optim_d: adam
+    optim_conf_d:
+        lr: 0.0002 # use small lr for gan training
+    scheduler_d: constantlr
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 1 # in gan training, accum_grad must be 1
+    log_interval: 100
     save_per_step: -1

+ 16 - 2
examples/libritts/cosyvoice/run.sh

@@ -83,9 +83,9 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   fi
   cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
   cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
-  for model in llm flow; do
+  for model in llm flow hifigan; do
     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
-        --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
+        --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
       cosyvoice/bin/train.py \
       --train_engine $train_engine \
       --config conf/cosyvoice.yaml \
@@ -104,7 +104,21 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   done
 fi
 
+# average model
+average_num=5
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+  for model in llm flow hifigan; do
+    decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
+    echo "do model average and final checkpoint is $decode_checkpoint"
+    python cosyvoice/bin/average_model.py \
+      --dst_model $decode_checkpoint \
+      --src_path `pwd`/exp/cosyvoice/$model/$train_engine  \
+      --num ${average_num} \
+      --val_best
+  done
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
   echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
   python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
   python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir