Browse Source

[debug] a better solution for mismatch of speech feat len and speech token len, refer to https://github.com/FunAudioLLM/CosyVoice/issues/1051

burkliu 1 year ago
parent
commit
65ad448714
2 changed files with 10 additions and 4 deletions
  1. 10 2
      cosyvoice/dataset/processor.py
  2. 0 2
      cosyvoice/flow/flow.py

+ 10 - 2
cosyvoice/dataset/processor.py

@@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
 
 def compute_fbank(data,
                   feat_extractor,
+                  token_mel_ratio=2,
                   mode='train'):
     """ Extract fbank
 
@@ -174,8 +175,15 @@ def compute_fbank(data,
         assert 'utt' in sample
         assert 'text_token' in sample
         waveform = sample['speech']
-        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
-        sample['speech_feat'] = mat
+        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+
+        # padding with replicate mode (align to speech_token len * token_mel_ratio)
+        pad_len = sample["speech_token"].shape[0] * token_mel_ratio - feat.shape[0]
+        if pad_len > 0:
+            feat_to_pad = feat[-1:].repeat((pad_len, 1))
+            feat = torch.cat([feat, feat_to_pad], dim=0)
+
+        sample['speech_feat'] = feat
         yield sample
 
 

+ 0 - 2
cosyvoice/flow/flow.py

@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         mask = (~make_pad_mask(feat_len)).to(h)
         # NOTE this is unnecessary, feat/h already same shape
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),
             mask.unsqueeze(1),
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         h = self.encoder_proj(h)
 
         # get conditions
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         conds = torch.zeros(feat.shape, device=token.device)
         for i, j in enumerate(feat_len):
             if random.random() < 0.5: