
Merge pull request #1232 from boji123/bj_dev_feat_len_pad

A better solution for the mismatch between speech feat length and speech token length during training
Xiang Lyu · 7 months ago
commit 88f467a8ac

+ 2 - 1
README.md

@@ -134,10 +134,11 @@ prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
 for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
     torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 
-# save zero_shot spk for futher usage
+# save zero_shot spk for future usage
 assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True
 for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
     torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+cosyvoice.save_spkinfo()
 
 # fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
 for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
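
Reviewer note: the new save_spkinfo() persists the in-memory spk2info dict to spk2info.pt under the model directory, so a speaker registered with add_zero_shot_spk can outlive the process. A minimal sketch of a later session, assuming the frontend reloads spk2info.pt from model_dir at startup (the same path save_spkinfo writes to):

    import sys
    sys.path.append('third_party/Matcha-TTS')
    from cosyvoice.cli.cosyvoice import CosyVoice2
    import torchaudio

    # Fresh process: no prompt audio needed, the cached speaker entry is reused
    # (assumes the frontend loads spk2info.pt from model_dir at init).
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B')
    for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
        torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)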

+ 7 - 4
cosyvoice/cli/cosyvoice.py

@@ -74,6 +74,9 @@ class CosyVoice:
         self.frontend.spk2info[zero_shot_spk_id] = model_input
         return True
 
+    def save_spkinfo(self):
+        torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
+
     def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_sft(i, spk_id)
@@ -99,9 +102,9 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 
-    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -174,10 +177,10 @@ class CosyVoice2(CosyVoice):
     def inference_instruct(self, *args, **kwargs):
         raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
 
-    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
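
Reviewer note: threading zero_shot_spk_id through inference_cross_lingual and inference_instruct2 lets both modes reuse a speaker cached by add_zero_shot_spk instead of re-extracting prompt features on every call. A hedged usage sketch (assumes 'my_zero_shot_spk' is already registered; prompt_speech_16k is still required by the signature, but the cached entry should take precedence):

    for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, zero_shot_spk_id='my_zero_shot_spk', stream=False)):
        torchaudio.save('instruct_spk_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)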

+ 4 - 4
cosyvoice/cli/frontend.py

@@ -178,8 +178,8 @@ class CosyVoiceFrontEnd:
         model_input['text_len'] = tts_text_token_len
         return model_input
 
-    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
-        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
+    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
         # in cross lingual mode, we remove prompt in llm
         del model_input['prompt_text']
         del model_input['prompt_text_len']
@@ -196,8 +196,8 @@ class CosyVoiceFrontEnd:
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input
 
-    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
-        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
+    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
         del model_input['llm_prompt_speech_token']
         del model_input['llm_prompt_speech_token_len']
         return model_input

+ 9 - 2
cosyvoice/dataset/processor.py

@@ -159,6 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
 
 def compute_fbank(data,
                   feat_extractor,
+                  token_mel_ratio=2,
                   mode='train'):
     """ Extract fbank
 
@@ -174,8 +175,14 @@ def compute_fbank(data,
         assert 'utt' in sample
         assert 'text_token' in sample
         waveform = sample['speech']
-        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
-        sample['speech_feat'] = mat
+        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
+
+        # trim to align speech_token and speech_feat
+        token_len = min(feat.shape[0] // token_mel_ratio, sample["speech_token"].shape[0])
+        feat = feat[:token_mel_ratio * token_len]
+        sample["speech_token"] = sample["speech_token"][:token_len]
+
+        sample['speech_feat'] = feat
         yield sample
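
Reviewer note: this is the fix the commit message refers to. Instead of padding or interpolating downstream, both sequences are trimmed so that speech_feat always holds exactly token_mel_ratio mel frames per speech token. A worked example of the arithmetic with the default token_mel_ratio=2 (shapes are illustrative):

    import torch

    token_mel_ratio = 2
    feat = torch.randn(203, 80)      # 203 mel frames from the extractor
    speech_token = torch.zeros(100)  # 100 discrete speech tokens

    token_len = min(feat.shape[0] // token_mel_ratio, speech_token.shape[0])  # min(101, 100) = 100
    feat = feat[:token_mel_ratio * token_len]  # 203 -> 200 frames
    speech_token = speech_token[:token_len]    # 100 tokens, unchanged here
    assert feat.shape[0] == token_mel_ratio * speech_token.shape[0]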
 
 

+ 0 - 2
cosyvoice/flow/flow.py

@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         mask = (~make_pad_mask(feat_len)).to(h)
         # NOTE this is unnecessary, feat/h already same shape
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         loss, _ = self.decoder.compute_loss(
             feat.transpose(1, 2).contiguous(),
             mask.unsqueeze(1),
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         h = self.encoder_proj(h)
 
         # get conditions
-        feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
         conds = torch.zeros(feat.shape, device=token.device)
         for i, j in enumerate(feat_len):
             if random.random() < 0.5:
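
Reviewer note: with compute_fbank now trimming feat to exactly token_mel_ratio * token_len frames, feat and the encoder output h already share a time axis, so the nearest-neighbor resize was at best a no-op and at worst silently papered over misaligned training pairs. A quick sketch showing the resize is the identity once the shapes agree (dimensions assumed for illustration):

    import torch
    import torch.nn.functional as F

    feat = torch.randn(4, 200, 80)  # trimmed to token_mel_ratio * token_len frames
    h = torch.randn(4, 200, 80)     # encoder output on the same time axis
    resized = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
    assert torch.equal(resized, feat)  # identity when shapes already match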

+ 37 - 0
test1.py

@@ -0,0 +1,37 @@
+import sys
+sys.path.append('third_party/Matcha-TTS')
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import load_wav
+import torchaudio # type: ignore
+
+cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
+
+# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
+# zero_shot usage
+prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
+for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+# save zero_shot spk for future usage
+assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True
+for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+cosyvoice.save_spkinfo()
+
+# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
+for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
+    torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+# instruct usage
+for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
+    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+# bistream usage, you can use generator as input, this is useful when using text llm model as input
+# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
+def text_generator():
+    yield '收到好友从远方寄来的生日礼物,'
+    yield '那份意外的惊喜与深深的祝福'
+    yield '让我心中充满了甜蜜的快乐,'
+    yield '笑容如花儿般绽放。'
+for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)