lyuxiang.lx пре 1 година
родитељ
комит
a13411c561
4 измењених фајлова са 123 додато и 44 уклоњено
  1. 30 21
      cosyvoice/cli/cosyvoice.py
  2. 82 22
      cosyvoice/cli/model.py
  3. 7 1
      cosyvoice/llm/llm.py
  4. 4 0
      cosyvoice/utils/file_utils.py

+ 30 - 21
cosyvoice/cli/cosyvoice.py

@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import torch
+import time
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
@@ -44,40 +45,48 @@ class CosyVoice:
         spks = list(self.frontend.spk2info.keys())
         return spks
 
-    def inference_sft(self, tts_text, spk_id):
-        tts_speeches = []
+    def inference_sft(self, tts_text, spk_id, stream=False):
+        start_time = time.time()
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_sft(i, spk_id)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
 
-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
+        start_time = time.time()
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
 
-    def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        tts_speeches = []
+        start_time = time.time()
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
 
-    def inference_instruct(self, tts_text, spk_id, instruct_text):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
+        start_time = time.time()
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

+ 82 - 22
cosyvoice/cli/model.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import numpy as np
+
 
 class CosyVoiceModel:
 
@@ -23,6 +25,10 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.stream_win_len = 60
+        self.stream_hop_len = 50
+        self.overlap = 4395 # 10 token equals 4395 sample point
+        self.window = np.hamming(2 * self.overlap)
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -36,25 +42,79 @@ class CosyVoiceModel:
                   prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
                   llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                   flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
-        tts_speech_token = self.llm.inference(text=text.to(self.device),
-                                              text_len=text_len.to(self.device),
-                                              prompt_text=prompt_text.to(self.device),
-                                              prompt_text_len=prompt_text_len.to(self.device),
-                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                              prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
-                                              embedding=llm_embedding.to(self.device),
-                                              beam_size=1,
-                                              sampling=25,
-                                              max_token_text_ratio=30,
-                                              min_token_text_ratio=3)
-        tts_mel = self.flow.inference(token=tts_speech_token,
-                                      token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                      prompt_token=flow_prompt_speech_token.to(self.device),
-                                      prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                      prompt_feat=prompt_speech_feat.to(self.device),
-                                      prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                      embedding=flow_embedding.to(self.device))
-        tts_speech = self.hift.inference(mel=tts_mel).cpu()
-        torch.cuda.empty_cache()
-        return {'tts_speech': tts_speech}
+                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
+        if stream is True:
+            tts_speech_token, cache_speech = [], None
+            for i in self.llm.inference(text=text.to(self.device),
+                                                text_len=text_len.to(self.device),
+                                                prompt_text=prompt_text.to(self.device),
+                                                prompt_text_len=prompt_text_len.to(self.device),
+                                                prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                                embedding=llm_embedding.to(self.device),
+                                                beam_size=1,
+                                                sampling=25,
+                                                max_token_text_ratio=30,
+                                                min_token_text_ratio=3,
+                                                stream=stream):
+                tts_speech_token.append(i)
+                if len(tts_speech_token) == self.stream_win_len:
+                    this_tts_speech_token = torch.concat(tts_speech_token, dim=1)
+                    this_tts_mel = self.flow.inference(token=this_tts_speech_token,
+                                                token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                                prompt_token=flow_prompt_speech_token.to(self.device),
+                                                prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                                prompt_feat=prompt_speech_feat.to(self.device),
+                                                prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                                embedding=flow_embedding.to(self.device))
+                    this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
+                    # fade in/out if necessary
+                    if cache_speech is not None:
+                        this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
+                    yield  {'tts_speech': this_tts_speech[:, :-self.overlap]}
+                    cache_speech = this_tts_speech[:, -self.overlap:]
+                    tts_speech_token = tts_speech_token[-(self.stream_win_len - self.stream_hop_len):]
+            # deal with remain tokens
+            if cache_speech is None or len(tts_speech_token) > self.stream_win_len - self.stream_hop_len:
+                this_tts_speech_token = torch.concat(tts_speech_token, dim=1)
+                this_tts_mel = self.flow.inference(token=this_tts_speech_token,
+                                            token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                            prompt_token=flow_prompt_speech_token.to(self.device),
+                                            prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                            prompt_feat=prompt_speech_feat.to(self.device),
+                                            prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                            embedding=flow_embedding.to(self.device))
+                this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
+                if cache_speech is not None:
+                    this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
+                yield {'tts_speech': this_tts_speech}
+            else:
+                assert len(tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len)
+                yield {'tts_speech': cache_speech}
+        else:
+            tts_speech_token = []
+            for i in self.llm.inference(text=text.to(self.device),
+                                                text_len=text_len.to(self.device),
+                                                prompt_text=prompt_text.to(self.device),
+                                                prompt_text_len=prompt_text_len.to(self.device),
+                                                prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                                prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                                embedding=llm_embedding.to(self.device),
+                                                beam_size=1,
+                                                sampling=25,
+                                                max_token_text_ratio=30,
+                                                min_token_text_ratio=3,
+                                                stream=stream):
+                tts_speech_token.append(i)
+            assert len(tts_speech_token) == 1, 'tts_speech_token len should be 1 when stream is {}'.format(stream)
+            tts_speech_token = torch.concat(tts_speech_token, dim=1)
+            tts_mel = self.flow.inference(token=tts_speech_token,
+                                        token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                        prompt_token=flow_prompt_speech_token.to(self.device),
+                                        prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                        prompt_feat=prompt_speech_feat.to(self.device),
+                                        prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                        embedding=flow_embedding.to(self.device))
+            tts_speech = self.hift.inference(mel=tts_mel).cpu()
+            torch.cuda.empty_cache()
+            yield {'tts_speech': tts_speech}

+ 7 - 1
cosyvoice/llm/llm.py

@@ -158,6 +158,7 @@ class TransformerLM(torch.nn.Module):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            stream: bool = False,
     ) -> torch.Tensor:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -199,8 +200,13 @@ class TransformerLM(torch.nn.Module):
             top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
+            # in stream mode, yield token one by one
+            if stream is True:
+                yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
 
-        return torch.tensor([out_tokens], dtype=torch.int64, device=device)
+        # in non-stream mode, yield all token
+        if stream is False:
+            yield torch.tensor([out_tokens], dtype=torch.int64, device=device)

+ 4 - 0
cosyvoice/utils/file_utils.py

@@ -15,6 +15,10 @@
 
 import json
 import torchaudio
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s %(levelname)s %(message)s')
 
 
 def read_lists(list_file):