Sfoglia il codice sorgente

use wav file rather than tensor

lyuxiang.lx 4 mesi fa
parent
commit
622a3a19b0
3 ha cambiato i file con 31 aggiunte e 30 eliminazioni
  1. 10 10
      cosyvoice/cli/cosyvoice.py
  2. 19 18
      cosyvoice/cli/frontend.py
  3. 2 2
      cosyvoice/utils/file_utils.py

+ 10 - 10
cosyvoice/cli/cosyvoice.py

@@ -67,9 +67,9 @@ class CosyVoice:
         spks = list(self.frontend.spk2info.keys())
         return spks
 
-    def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
+    def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
         assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
-        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
+        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
         del model_input['text']
         del model_input['text_len']
         self.frontend.spk2info[zero_shot_spk_id] = model_input
@@ -89,12 +89,12 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 
-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
-            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -103,9 +103,9 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 
-    def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+    def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -129,8 +129,8 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 
-    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
-        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
+    def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
+        model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
         start_time = time.time()
         for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
             speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
@@ -181,10 +181,10 @@ class CosyVoice2(CosyVoice):
     def inference_instruct(self, *args, **kwargs):
         raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
 
-    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+    def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):

+ 19 - 18
cosyvoice/cli/frontend.py

@@ -32,7 +32,7 @@ except ImportError:
     from wetext import Normalizer as ZhNormalizer
     from wetext import Normalizer as EnNormalizer
     use_ttsfrd = False
-from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.file_utils import logging, load_wav
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
 
 
@@ -89,7 +89,8 @@ class CosyVoiceFrontEnd:
             for i in range(text_token.shape[1]):
                 yield text_token[:, i: i + 1]
 
-    def _extract_speech_token(self, speech):
+    def _extract_speech_token(self, prompt_wav):
+        speech = load_wav(prompt_wav, 16000)
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
         speech_token = self.speech_tokenizer_session.run(None,
@@ -101,7 +102,8 @@ class CosyVoiceFrontEnd:
         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
         return speech_token, speech_token_len
 
-    def _extract_spk_embedding(self, speech):
+    def _extract_spk_embedding(self, prompt_wav):
+        speech = load_wav(prompt_wav, 16000)
         feat = kaldi.fbank(speech,
                            num_mel_bins=80,
                            dither=0,
@@ -112,7 +114,8 @@ class CosyVoiceFrontEnd:
         embedding = torch.tensor([embedding]).to(self.device)
         return embedding
 
-    def _extract_speech_feat(self, speech):
+    def _extract_speech_feat(self, prompt_wav):
+        speech = load_wav(prompt_wav, 24000)
         speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
         speech_feat = speech_feat.unsqueeze(dim=0)
         speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
@@ -154,19 +157,18 @@ class CosyVoiceFrontEnd:
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input
 
-    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
         if zero_shot_spk_id == '':
             prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
-            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-            speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
+            speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
             if resample_rate == 24000:
                 # cosyvoice2, force speech_feat % speech_token = 2
                 token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
                 speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
                 speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
-            embedding = self._extract_spk_embedding(prompt_speech_16k)
+            embedding = self._extract_spk_embedding(prompt_wav)
             model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                            'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                            'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
@@ -178,8 +180,8 @@ class CosyVoiceFrontEnd:
         model_input['text_len'] = tts_text_token_len
         return model_input
 
-    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
-        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+    def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
         # in cross lingual mode, we remove prompt in llm
         del model_input['prompt_text']
         del model_input['prompt_text_len']
@@ -196,17 +198,16 @@ class CosyVoiceFrontEnd:
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input
 
-    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
-        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+    def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_wav, resample_rate, zero_shot_spk_id)
         del model_input['llm_prompt_speech_token']
         del model_input['llm_prompt_speech_token_len']
         return model_input
 
-    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
-        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-        embedding = self._extract_spk_embedding(prompt_speech_16k)
+    def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
+        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
+        embedding = self._extract_spk_embedding(prompt_wav)
         source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
         model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                        'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,

+ 2 - 2
cosyvoice/utils/file_utils.py

@@ -41,11 +41,11 @@ def read_json_lists(list_file):
     return results
 
 
-def load_wav(wav, target_sr):
+def load_wav(wav, target_sr, min_sr=16000):
     speech, sample_rate = torchaudio.load(wav, backend='soundfile')
     speech = speech.mean(dim=0, keepdim=True)
     if sample_rate != target_sr:
-        assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+        assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech