1
0
lyuxiang.lx 1 жил өмнө
parent
commit
02f941d348

+ 4 - 4
cosyvoice/cli/cosyvoice.py

@@ -46,9 +46,9 @@ class CosyVoice:
         return spks
 
     def inference_sft(self, tts_text, spk_id, stream=False):
-        start_time = time.time()
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_sft(i, spk_id)
+            start_time = time.time()
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -56,10 +56,10 @@ class CosyVoice:
                 start_time = time.time()
 
     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
-        start_time = time.time()
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
+            start_time = time.time()
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -69,9 +69,9 @@ class CosyVoice:
     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        start_time = time.time()
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+            start_time = time.time()
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -81,10 +81,10 @@ class CosyVoice:
     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
-        start_time = time.time()
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+            start_time = time.time()
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))

+ 64 - 32
cosyvoice/cli/model.py

@@ -13,6 +13,9 @@
 # limitations under the License.
 import torch
 import numpy as np
+import threading
+import time
+from contextlib import nullcontext
 
 
 class CosyVoiceModel:
@@ -25,10 +28,13 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.stream_win_len = 60
-        self.stream_hop_len = 50
-        self.overlap = 4395 # 10 token equals 4395 sample point
+        self.stream_win_len = 60 * 4
+        self.stream_hop_len = 50 * 4
+        self.overlap = 4395 * 4 # 10 token equals 4395 sample point
         self.window = np.hamming(2 * self.overlap)
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -38,13 +44,8 @@ class CosyVoiceModel:
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()
 
-    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
-        if stream is True:
-            tts_speech_token, cache_speech = [], None
+    def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding):
+        with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
                                                 text_len=text_len.to(self.device),
                                                 prompt_text=prompt_text.to(self.device),
@@ -56,10 +57,56 @@ class CosyVoiceModel:
                                                 sampling=25,
                                                 max_token_text_ratio=30,
                                                 min_token_text_ratio=3,
-                                                stream=stream):
-                tts_speech_token.append(i)
-                if len(tts_speech_token) == self.stream_win_len:
-                    this_tts_speech_token = torch.concat(tts_speech_token, dim=1)
+                                                stream=True):
+                self.tts_speech_token.append(i)
+        self.llm_end = True
+
+    def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding):
+        with self.flow_hift_context:
+            tts_mel = self.flow.inference(token=token.to(self.device),
+                                        token_len=torch.tensor([token.size(1)], dtype=torch.int32).to(self.device),
+                                        prompt_token=prompt_token.to(self.device),
+                                        prompt_token_len=prompt_token_len.to(self.device),
+                                        prompt_feat=prompt_feat.to(self.device),
+                                        prompt_feat_len=prompt_feat_len.to(self.device),
+                                        embedding=embedding.to(self.device))
+            tts_speech = self.hift.inference(mel=tts_mel).cpu()
+        return tts_speech
+
+    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
+                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
+                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
+                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
+                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
+        if stream is True:
+            self.tts_speech_token, self.llm_end, cache_speech = [], False, None
+            p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
+                                                     llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device)))
+            p.start()
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token) >= self.stream_win_len:
+                    this_tts_speech_token = torch.concat(self.tts_speech_token[:self.stream_win_len], dim=1)
+                    with self.flow_hift_context:
+                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                    prompt_token=flow_prompt_speech_token.to(self.device),
+                                                    prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                                    prompt_feat=prompt_speech_feat.to(self.device),
+                                                    prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                                    embedding=flow_embedding.to(self.device))
+                    # fade in/out if necessary
+                    if cache_speech is not None:
+                        this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
+                    yield  {'tts_speech': this_tts_speech[:, :-self.overlap]}
+                    cache_speech = this_tts_speech[:, -self.overlap:]
+                    with self.lock:
+                        self.tts_speech_token = self.tts_speech_token[self.stream_hop_len:]
+                if self.llm_end is True:
+                    break
+            # deal with remain tokens
+            if cache_speech is None or len(self.tts_speech_token) > self.stream_win_len - self.stream_hop_len:
+                this_tts_speech_token = torch.concat(self.tts_speech_token, dim=1)
+                with self.flow_hift_context:
                     this_tts_mel = self.flow.inference(token=this_tts_speech_token,
                                                 token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
                                                 prompt_token=flow_prompt_speech_token.to(self.device),
@@ -68,29 +115,14 @@ class CosyVoiceModel:
                                                 prompt_feat_len=prompt_speech_feat_len.to(self.device),
                                                 embedding=flow_embedding.to(self.device))
                     this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
-                    # fade in/out if necessary
-                    if cache_speech is not None:
-                        this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
-                    yield  {'tts_speech': this_tts_speech[:, :-self.overlap]}
-                    cache_speech = this_tts_speech[:, -self.overlap:]
-                    tts_speech_token = tts_speech_token[-(self.stream_win_len - self.stream_hop_len):]
-            # deal with remain tokens
-            if cache_speech is None or len(tts_speech_token) > self.stream_win_len - self.stream_hop_len:
-                this_tts_speech_token = torch.concat(tts_speech_token, dim=1)
-                this_tts_mel = self.flow.inference(token=this_tts_speech_token,
-                                            token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                            prompt_token=flow_prompt_speech_token.to(self.device),
-                                            prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                            prompt_feat=prompt_speech_feat.to(self.device),
-                                            prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                            embedding=flow_embedding.to(self.device))
-                this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
                 if cache_speech is not None:
                     this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
                 yield {'tts_speech': this_tts_speech}
             else:
-                assert len(tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len)
+                assert len(self.tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len)
                 yield {'tts_speech': cache_speech}
+            p.join()
+            torch.cuda.synchronize()
         else:
             tts_speech_token = []
             for i in self.llm.inference(text=text.to(self.device),

+ 1 - 1
cosyvoice/flow/length_regulator.py

@@ -43,7 +43,7 @@ class InterpolateRegulator(nn.Module):
     def forward(self, x, ylens=None):
         # x in (B, T, D)
         mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
-        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+        x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
         out = self.model(x).transpose(1, 2).contiguous()
         olens = ylens
         return out * mask, olens

+ 2 - 2
cosyvoice/llm/llm.py

@@ -174,7 +174,7 @@ class TransformerLM(torch.nn.Module):
             embedding = self.spk_embed_affine_layer(embedding)
             embedding = embedding.unsqueeze(dim=1)
         else:
-            embedding = torch.zeros(1, 0, self.llm_input_size).to(device)
+            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
 
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -182,7 +182,7 @@ class TransformerLM(torch.nn.Module):
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size).to(device)
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
         lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
         # 4. cal min/max_length

+ 14 - 25
webui.py

@@ -24,14 +24,8 @@ import torchaudio
 import random
 import librosa
 
-import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
-
 from cosyvoice.cli.cosyvoice import CosyVoice
-from cosyvoice.utils.file_utils import load_wav, speed_change
-
-logging.basicConfig(level=logging.DEBUG,
-                    format='%(asctime)s %(levelname)s %(message)s')
+from cosyvoice.utils.file_utils import load_wav, speed_change, logging
 
 def generate_seed():
     seed = random.randint(1, 100000000)
@@ -63,10 +57,11 @@ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成
                  '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
                  '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
                  '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
+stream_mode_list = [('否', False), ('是', True)]
 def change_instruction(mode_checkbox_group):
     return instruct_dict[mode_checkbox_group]
 
-def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor):
+def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor):
     if prompt_wav_upload is not None:
         prompt_wav = prompt_wav_upload
     elif prompt_wav_record is not None:
@@ -117,32 +112,25 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
     if mode_checkbox_group == '预训练音色':
         logging.info('get sft inference request')
         set_all_random_seed(seed)
-        output = cosyvoice.inference_sft(tts_text, sft_dropdown)
+        for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
+            yield (target_sr,  i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '3s极速复刻':
         logging.info('get zero_shot inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
+        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
+            yield (target_sr,  i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '跨语种复刻':
         logging.info('get cross_lingual inference request')
         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
+        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
+            yield (target_sr,  i['tts_speech'].numpy().flatten())
     else:
         logging.info('get instruct inference request')
         set_all_random_seed(seed)
-        output = cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text)
-    
-    if speed_factor != 1.0:
-        try:
-            audio_data, sample_rate = speed_change(output["tts_speech"], target_sr, str(speed_factor))
-            audio_data = audio_data.numpy().flatten()
-        except Exception as e:
-            print(f"Failed to change speed of audio: \n{e}")
-    else:
-        audio_data = output['tts_speech'].numpy().flatten()
-
-    return (target_sr, audio_data)
+        for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
+            yield (target_sr,  i['tts_speech'].numpy().flatten())
 
 def main():
     with gr.Blocks() as demo:
@@ -155,6 +143,7 @@ def main():
             mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
             instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
             sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
+            stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
             with gr.Column(scale=0.25):
                 seed_button = gr.Button(value="\U0001F3B2")
                 seed = gr.Number(value=0, label="随机推理种子")
@@ -167,11 +156,11 @@ def main():
 
         generate_button = gr.Button("生成音频")
 
-        audio_output = gr.Audio(label="合成音频")
+        audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
 
         seed_button.click(generate_seed, inputs=[], outputs=seed)
         generate_button.click(generate_audio,
-                              inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, speed_factor],
+                              inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor],
                               outputs=[audio_output])
         mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
     demo.queue(max_size=4, default_concurrency_limit=2)