瀏覽代碼

update hifigan

lyuxiang.lx 1 年之前
父節點
當前提交
73784974ce
共有 3 個文件被更改,包括 64 次插入和 5 次删除
  1. +6 -2
      cosyvoice/bin/train.py
  2. +0 -2
      cosyvoice/dataset/processor.py
  3. +58 -1
      examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

+ 6 - 2
cosyvoice/bin/train.py

@@ -18,6 +18,7 @@ import datetime
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 from copy import deepcopy
+import os
 import torch
 import torch.distributed as dist
 import deepspeed
@@ -112,7 +113,10 @@ def main():
     # load checkpoint
     model = configs[args.model]
     if args.checkpoint is not None:
-        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
+        if os.path.exists(args.checkpoint):
+            model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
+        else:
+            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
 
     # Dispatch model from cpu to gpu
     model = wrap_cuda_model(args, model)
@@ -125,7 +129,7 @@ def main():
     save_model(model, 'init', info_dict)
 
     # Get executor
-    executor = Executor()
+    executor = Executor(gan=gan)
 
     # Start training loop
     for epoch in range(info_dict['max_epoch']):

+ 0 - 2
cosyvoice/dataset/processor.py

@@ -393,8 +393,6 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
             "speech_token_len": speech_token_len,
             "speech_feat": speech_feat,
             "speech_feat_len": speech_feat_len,
-            "pitch_feat": pitch_feat,
-            "pitch_feat_len": pitch_feat_len,
             "text": text,
             "text_token": text_token,
             "text_token_len": text_token_len,

+ 58 - 1
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

@@ -133,6 +133,25 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
         in_channels: 80
         cond_channels: 512
 
+# gan related module
+mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: 8000
+    center: False
+hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
+    generator: !ref <hift>
+    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
+        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+    mel_spec_transform: [
+        !ref <mel_spec_transform1>
+    ]
+
 # processor functions
 parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
 get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
@@ -151,6 +170,8 @@ filter: !name:cosyvoice.dataset.processor.filter
     token_min_length: 1
 resample: !name:cosyvoice.dataset.processor.resample
     resample_rate: !ref <sample_rate>
+truncate: !name:cosyvoice.dataset.processor.truncate
+    truncate_length: 24576 # must be a multiple of hop_size
 feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     n_fft: 1024
     num_mels: 80
@@ -162,6 +183,12 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
+pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch
+    sample_rate: !ref <sample_rate>
+    frame_length: 46.4 # in milliseconds; matches feat_extractor win_size / sampling_rate
+    frame_shift: 11.6 # in milliseconds; matches feat_extractor hop_size / sampling_rate
+compute_f0: !name:cosyvoice.dataset.processor.compute_f0
+    pitch_extractor: !ref <pitch_extractor>
 parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
     normalize: True
 shuffle: !name:cosyvoice.dataset.processor.shuffle
@@ -187,8 +214,22 @@ data_pipeline: [
     !ref <batch>,
     !ref <padding>,
 ]
+data_pipeline_gan: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <truncate>,
+    !ref <compute_fbank>,
+    !ref <compute_f0>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
 
-# train conf
+# llm flow train conf
 train_conf:
     optim: adam
     optim_conf:
@@ -200,4 +241,20 @@ train_conf:
     grad_clip: 5
     accum_grad: 2
     log_interval: 100
+    save_per_step: -1
+
+# gan train conf
+train_conf_gan:
+    optim: adam
+    optim_conf:
+        lr: 0.0002 # use small lr for gan training
+    scheduler: constantlr
+    optim_d: adam
+    optim_conf_d:
+        lr: 0.0002 # use small lr for gan training
+    scheduler_d: constantlr
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 1 # in gan training, accum_grad must be 1
+    log_interval: 100
     save_per_step: -1