Quellcode durchsuchen

fix vocoder train

lyuxiang.lx vor 11 Monaten
Ursprung
Commit
a69b7e275d

+ 2 - 1
cosyvoice/cli/model.py

@@ -299,7 +299,8 @@ class CosyVoice2Model(CosyVoiceModel):
             self.flow.half()
         self.token_hop_len = self.flow.encoder.static_chunk_size
         # flow decoder required_cache_size
-        self.flow_decoder_required_cache_size = self.flow.decoder.estimator.num_decoding_left_chunks * self.flow.decoder.estimator.static_chunk_size
+        # TODO 基模型训练时没有设置num_decoding_left_chunks,需要重新训一下才能指定flow_decoder_required_cache_size
+        self.flow_decoder_required_cache_size = 999
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)

+ 2 - 2
cosyvoice/flow/flow.py

@@ -91,7 +91,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(feat_len)).to(h)
-        # NOTE 这一句应该是不需要的,应该h已经过length_regulator跟feat一样的shape了
+        # NOTE this is unnecessary, feat/h already same shape
         feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),
@@ -117,7 +117,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)
 
-        # concat text and prompt_text
+        # concat speech token and prompt speech token
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)

+ 1 - 0
cosyvoice/flow/length_regulator.py

@@ -51,6 +51,7 @@ class InterpolateRegulator(nn.Module):
 
     def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
         # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
+        # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
         # x in (B, T, D)
         if x2.shape[1] > 40:
             x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')

+ 89 - 2
cosyvoice/hifigan/discriminator.py

@@ -1,13 +1,16 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 try:
-    from torch.nn.utils.parametrizations import weight_norm
+    from torch.nn.utils.parametrizations import weight_norm, spectral_norm
 except ImportError:
-    from torch.nn.utils import weight_norm
+    from torch.nn.utils import weight_norm, spectral_norm
 from typing import List, Optional, Tuple
 from einops import rearrange
 from torchaudio.transforms import Spectrogram
 
+LRELU_SLOPE = 0.1
+
 
 class MultipleDiscriminator(nn.Module):
     def __init__(
@@ -141,3 +144,87 @@ class DiscriminatorR(nn.Module):
         x += h
 
         return x, fmap
+
+
+class MultiResSpecDiscriminator(torch.nn.Module):
+
+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window="hann_window"):
+
+        super(MultiResSpecDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
+            SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
+            SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (str): Window function type.
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+    """
+    x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
+
+    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+    return torch.abs(x_stft).transpose(2, 1)
+
+
+class SpecDiscriminator(nn.Module):
+    """docstring for Discriminator."""
+
+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
+        super(SpecDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.discriminators = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
+        ])
+
+        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
+
+    def forward(self, y):
+
+        fmap = []
+        y = y.squeeze(1)
+        y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
+        y = y.unsqueeze(1)
+        for i, d in enumerate(self.discriminators):
+            y = d(y)
+            y = F.leaky_relu(y, LRELU_SLOPE)
+            fmap.append(y)
+
+        y = self.out(y)
+        fmap.append(y)
+
+        return torch.flatten(y, 1, -1), fmap

+ 1 - 1
cosyvoice/hifigan/hifigan.py

@@ -56,7 +56,7 @@ class HiFiGan(nn.Module):
         with torch.no_grad():
             generated_speech, generated_f0 = self.generator(batch, device)
         # 2. calculate discriminator outputs
-        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
+        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
         # 3. calculate discriminator losses, tpr losses [Optional]
         loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
         if self.tpr_loss_weight != 0:

+ 2 - 1
cosyvoice/llm/llm.py

@@ -326,7 +326,8 @@ class Qwen2LM(TransformerLM):
             # unistream sequence
             else:
                 this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i], self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
+                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
+                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)

+ 1 - 1
cosyvoice/utils/train_utils.py

@@ -340,7 +340,7 @@ def log_per_save(writer, info_dict):
     rank = int(os.environ.get('RANK', 0))
     logging.info(
         'Epoch {} Step {} CV info lr {} {} rank {}'.format(
-            epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
+            epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))
 
     if writer is not None:
         for k in ['epoch', 'lr']:

+ 1 - 1
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

@@ -147,7 +147,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
     generator: !ref <hift>
     discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
         mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
-        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
     mel_spec_transform: [
         !ref <mel_spec_transform1>
     ]

+ 1 - 1
examples/libritts/cosyvoice/conf/cosyvoice.yaml

@@ -147,7 +147,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
     generator: !ref <hift>
     discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
         mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
-        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
     mel_spec_transform: [
         !ref <mel_spec_transform1>
     ]

+ 6 - 6
examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

@@ -14,8 +14,8 @@ token_frame_rate: 25
 token_mel_ratio: 2
 
 # stream related params
-chunk_size: 1 # streaming inference chunk size, in second
-num_decoding_left_chunks: 2 # streaming inference flow decoder left chunk size, in second
+chunk_size: 2 # streaming inference chunk size, in second
+num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size
 
 # model params
 # for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
@@ -112,11 +112,11 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
 
 # gan related module
 mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
-    n_fft: 1024
+    n_fft: 1920
     num_mels: 80
     sampling_rate: !ref <sample_rate>
-    hop_size: 256
-    win_size: 1024
+    hop_size: 480
+    win_size: 1920
     fmin: 0
     fmax: null
     center: False
@@ -124,7 +124,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
     generator: !ref <hift>
     discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
         mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
-        mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
     mel_spec_transform: [
         !ref <mel_spec_transform1>
     ]

+ 1 - 1
examples/libritts/cosyvoice2/run.sh

@@ -71,7 +71,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
 fi
 
 # train llm
-export CUDA_VISIBLE_DEVICES="2,3,4,5,6,7"
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
 num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 job_id=1986
 dist_backend="nccl"

+ 1 - 0
requirements.txt

@@ -21,6 +21,7 @@ onnxruntime-gpu==1.18.0; sys_platform == 'linux'
 onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'windows'
 openai-whisper==20231117
 protobuf==4.25
+pyarrow==18.1.0
 pydantic==2.7.0
 pyworld==0.3.4
 rich==13.7.1