Forráskód Böngészése

Merge pull request #403 from FunAudioLLM/dev/lyuxiang.lx

update tempo change
Xiang Lyu 1 éve
szülő
commit
2898d5a851
4 módosított fájl, 23 hozzáadás és 31 törlés
  1. 8 8
      cosyvoice/cli/cosyvoice.py
  2. 8 3
      cosyvoice/cli/model.py
  3. 0 13
      cosyvoice/utils/file_utils.py
  4. 7 7
      webui.py

+ 8 - 8
cosyvoice/cli/cosyvoice.py

@@ -53,43 +53,43 @@ class CosyVoice:
         spks = list(self.frontend.spk2info.keys())
         return spks
 
-    def inference_sft(self, tts_text, spk_id, stream=False):
+    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_sft(i, spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
+            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
 
-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
+            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
 
-    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
+            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
 
-    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
@@ -97,7 +97,7 @@ class CosyVoice:
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
-            for model_output in self.model.inference(**model_input, stream=stream):
+            for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output

+ 8 - 3
cosyvoice/cli/model.py

@@ -15,6 +15,7 @@ import torch
 import numpy as np
 import threading
 import time
+from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
@@ -91,7 +92,7 @@ class CosyVoiceModel:
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
-    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
         tts_mel = self.flow.inference(token=token.to(self.device),
                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                       prompt_token=prompt_token.to(self.device),
@@ -116,6 +117,9 @@ class CosyVoiceModel:
             self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
             tts_speech = tts_speech[:, :-self.source_cache_len]
         else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
         return tts_speech
 
@@ -123,7 +127,7 @@ class CosyVoiceModel:
                   prompt_text=torch.zeros(1, 0, dtype=torch.int32),
                   llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                   flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
+                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
@@ -169,7 +173,8 @@ class CosyVoiceModel:
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
                                              uuid=this_uuid,
-                                             finalize=True)
+                                             finalize=True,
+                                             speed=speed)
             yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)

+ 0 - 13
cosyvoice/utils/file_utils.py

@@ -45,16 +45,3 @@ def load_wav(wav, target_sr):
         assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech
-
-
-def speed_change(waveform, sample_rate, speed_factor: str):
-    effects = [
-        ["tempo", speed_factor],  # speed_factor
-        ["rate", f"{sample_rate}"]
-    ]
-    augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
-        waveform,
-        sample_rate,
-        effects
-    )
-    return augmented_waveform, new_sample_rate

+ 7 - 7
webui.py

@@ -66,7 +66,7 @@ def change_instruction(mode_checkbox_group):
 
 
 def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
-                   seed, stream, speed_factor):
+                   seed, stream, speed):
     if prompt_wav_upload is not None:
         prompt_wav = prompt_wav_upload
     elif prompt_wav_record is not None:
@@ -117,24 +117,24 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group == '预训练音色':
         logging.info('get sft inference request')
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
+        for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed):
             yield (target_sr, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '3s极速复刻':
         logging.info('get zero_shot inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
+        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
             yield (target_sr, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '跨语种复刻':
         logging.info('get cross_lingual inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
+        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
             yield (target_sr, i['tts_speech'].numpy().flatten())
     else:
         logging.info('get instruct inference request')
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
+        for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
             yield (target_sr, i['tts_speech'].numpy().flatten())
 
 
@@ -147,12 +147,12 @@ def main():
         gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
 
         tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
-        speed_factor = gr.Slider(minimum=0.25, maximum=4, step=0.05, label="语速调节", value=1.0, interactive=True)
         with gr.Row():
             mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
             instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
             sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
             stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
+            speed = gr.Number(value=1, label="速度调节(仅支持非流式推理)", minimum=0.5, maximum=2.0, step=0.1)
             with gr.Column(scale=0.25):
                 seed_button = gr.Button(value="\U0001F3B2")
                 seed = gr.Number(value=0, label="随机推理种子")
@@ -170,7 +170,7 @@ def main():
         seed_button.click(generate_seed, inputs=[], outputs=seed)
         generate_button.click(generate_audio,
                               inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
-                                      seed, stream, speed_factor],
+                                      seed, stream, speed],
                               outputs=[audio_output])
         mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
     demo.queue(max_size=4, default_concurrency_limit=2)