1
0
lyuxiang.lx преди 8 месеца
родител
ревизия
190840b8dc

+ 9 - 4
cosyvoice/dataset/processor.py

@@ -20,6 +20,7 @@ import torch
 import torchaudio
 from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
+import pyworld as pw
 
 
 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
@@ -178,7 +179,7 @@ def compute_fbank(data,
         yield sample
 
 
-def compute_f0(data, pitch_extractor, mode='train'):
+def compute_f0(data, sample_rate, hop_size, mode='train'):
     """ Extract f0
 
         Args:
@@ -187,15 +188,19 @@ def compute_f0(data, pitch_extractor, mode='train'):
         Returns:
             Iterable[{key, feat, label}]
     """
+    frame_period = hop_size * 1000 / sample_rate
     for sample in data:
         assert 'sample_rate' in sample
         assert 'speech' in sample
         assert 'utt' in sample
         assert 'text_token' in sample
         waveform = sample['speech']
-        mat = pitch_extractor(waveform).transpose(1, 2)
-        mat = F.interpolate(mat, size=sample['speech_feat'].shape[0], mode='linear')
-        sample['pitch_feat'] = mat[0, 0]
+        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
+        if sum(_f0 != 0) < 5: # this happens when the algorithm fails
+            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
+        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
+        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
+        sample['pitch_feat'] = f0
         yield sample
 
 

+ 5 - 0
cosyvoice/utils/mask.py

@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import torch
+from cosyvoice.utils.file_utils import logging
 '''
 def subsequent_mask(
         size: int,
@@ -230,6 +231,10 @@ def add_optional_chunk_mask(xs: torch.Tensor,
         chunk_masks = masks & chunk_masks  # (B, L, L)
     else:
         chunk_masks = masks
+    assert chunk_masks.dtype == torch.bool
+    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
+        logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
+        chunk_masks[chunk_masks.sum(dim=-1)==0] = True
     return chunk_masks
 
 

+ 2 - 5
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

@@ -183,12 +183,9 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
-pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch
-    sample_rate: !ref <sample_rate>
-    frame_length: 46.4 # match feat_extractor win_size/sampling_rate
-    frame_shift: 11.6 # match feat_extractor hop_size/sampling_rate
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
-    pitch_extractor: !ref <pitch_extractor>
+    sample_rate: !ref <sample_rate>
+    hop_size: 256
 parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
     normalize: True
 shuffle: !name:cosyvoice.dataset.processor.shuffle

+ 2 - 5
examples/libritts/cosyvoice/conf/cosyvoice.yaml

@@ -183,12 +183,9 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
-pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch
-    sample_rate: !ref <sample_rate>
-    frame_length: 46.4 # match feat_extractor win_size/sampling_rate
-    frame_shift: 11.6 # match feat_extractor hop_size/sampling_rate
 compute_f0: !name:cosyvoice.dataset.processor.compute_f0
-    pitch_extractor: !ref <pitch_extractor>
+    sample_rate: !ref <sample_rate>
+    hop_size: 256
 parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
     normalize: True
 shuffle: !name:cosyvoice.dataset.processor.shuffle

+ 1 - 0
requirements.txt

@@ -22,6 +22,7 @@ onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'windows'
 openai-whisper==20231117
 protobuf==4.25
 pydantic==2.7.0
+pyworld==0.3.4
 rich==13.7.1
 soundfile==0.12.1
 tensorboard==2.14.0