Jelajahi Sumber

Merge pull request #719 from FunAudioLLM/dev/lyuxiang.lx

add instruct usage
Xiang Lyu 11 bulan lalu
induk
komit
c4688b68eb
4 file diubah dengan 37 tambahan dan 2 penghapusan
  1. 3 0
      README.md
  2. 13 0
      cosyvoice/cli/cosyvoice.py
  3. 20 1
      cosyvoice/cli/frontend.py
  4. 1 1
      webui.py

+ 3 - 0
README.md

@@ -139,6 +139,9 @@ cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
 for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
     torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+# instruct usage
+for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
+    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
 
 # cosyvoice
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)

+ 13 - 0
cosyvoice/cli/cosyvoice.py

@@ -98,6 +98,7 @@ class CosyVoice:
                 start_time = time.time()
 
     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
+        assert isinstance(self.model, CosyVoiceModel)
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
@@ -111,6 +112,18 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()
 
+    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0):
+        assert isinstance(self.model, CosyVoice2Model)
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
+            start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
+            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
+
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
         model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()

+ 20 - 1
cosyvoice/cli/frontend.py

@@ -152,7 +152,7 @@ class CosyVoiceFrontEnd:
         if resample_rate == 24000:
             # cosyvoice2, force speech_feat % speech_token = 2
             token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
-            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2* token_len
+            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
             speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
         embedding = self._extract_spk_embedding(prompt_speech_16k)
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
@@ -181,6 +181,25 @@ class CosyVoiceFrontEnd:
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input
 
+    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
+        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+        prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+        if resample_rate == 24000:
+            # cosyvoice2, force speech_feat % speech_token = 2
+            token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+            speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+            speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+        embedding = self._extract_spk_embedding(prompt_speech_16k)
+        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                       'llm_embedding': embedding, 'flow_embedding': embedding}
+        return model_input
+
     def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
         prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
         prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)

+ 1 - 1
webui.py

@@ -144,7 +144,7 @@ def main():
         with gr.Row():
             mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
             instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
-            sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
+            sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0] if len(sft_spk) != 0 else '', scale=0.25)
             stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
             speed = gr.Number(value=1, label="速度调节(仅支持非流式推理)", minimum=0.5, maximum=2.0, step=0.1)
             with gr.Column(scale=0.25):