|
|
@@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
|
|
|
|
|
|
def compute_fbank(data,
|
|
|
feat_extractor,
|
|
|
+ token_mel_ratio=2,
|
|
|
mode='train'):
|
|
|
""" Extract fbank
|
|
|
|
|
|
@@ -174,8 +175,15 @@ def compute_fbank(data,
|
|
|
assert 'utt' in sample
|
|
|
assert 'text_token' in sample
|
|
|
waveform = sample['speech']
|
|
|
- mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
|
|
- sample['speech_feat'] = mat
|
|
|
+ feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
|
|
+
|
|
|
+ # padding with replicate mode (align to speech_token len * token_mel_ratio)
|
|
|
+ pad_len = sample["speech_token"].shape[0] * token_mel_ratio - feat.shape[0]
|
|
|
+ if pad_len > 0:
|
|
|
+ feat_to_pad = feat[-1:].repeat((pad_len, 1))
|
|
|
+ feat = torch.cat([feat, feat_to_pad], dim=0)
|
|
|
+
|
|
|
+ sample['speech_feat'] = feat
|
|
|
yield sample
|
|
|
|
|
|
|