
set load_onnx to False as last-chunk RTF is unstable
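
For context: RTF (real-time factor) is processing time divided by the duration of the audio produced, so RTF < 1 means faster than real time. Per the commit message, the ONNX path made the final streaming chunk's RTF erratic, hence the default flip. A minimal sketch of how one might measure RTF (helper name and sample rate are illustrative, not from this repo):

import time
import torch

def measure_rtf(generate_fn, sample_rate=22050):
    # Time one synthesis call against the duration of the audio it yields.
    if torch.cuda.is_available():
        torch.cuda.synchronize()   # don't bill pending GPU work to this call
    start = time.time()
    speech = generate_fn()         # assumed to return a 1 x T waveform tensor
    if torch.cuda.is_available():
        torch.cuda.synchronize()   # wait for the kernels this call enqueued
    elapsed = time.time() - start
    return elapsed / (speech.shape[1] / sample_rate)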

lyuxiang.lx 1 year ago
parent
commit
122df8c420
3 changed files with 45 additions and 51 deletions
  1. .github/workflows/lint.yml (+1 -0)
  2. cosyvoice/cli/cosyvoice.py (+1 -1)
  3. cosyvoice/cli/model.py (+43 -50)

+ 1 - 0
.github/workflows/lint.yml

@@ -2,6 +2,7 @@ name: Lint
 
 on:
   pull_request:
+  push:
 
 jobs:
   quick-checks:

+ 1 - 1
cosyvoice/cli/cosyvoice.py

@@ -23,7 +23,7 @@ from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True, load_onnx=True):
+    def __init__(self, model_dir, load_jit=True, load_onnx=False):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):

+ 43 - 50
cosyvoice/cli/model.py

@@ -43,7 +43,6 @@ class CosyVoiceModel:
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
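
With flow_hift_context gone, flow and hift run on the default CUDA stream and only the LLM keeps a side stream. A sketch of the pattern that remains, assuming the same CPU fallback via nullcontext:

from contextlib import nullcontext
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# LLM work stays on its own stream; flow/hift now share the default
# stream, so their outputs need no cross-stream synchronization.
llm_context = torch.cuda.stream(torch.cuda.Stream(device)) if torch.cuda.is_available() else nullcontext()
with llm_context:
    pass  # enqueue LLM inference here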
@@ -93,32 +92,31 @@ class CosyVoiceModel:
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
-        with self.flow_hift_context:
-            tts_mel = self.flow.inference(token=token.to(self.device),
-                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                          prompt_token=prompt_token.to(self.device),
-                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                          prompt_feat=prompt_feat.to(self.device),
-                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                          embedding=embedding.to(self.device))
-            # mel overlap fade in out
-            if self.mel_overlap_dict[uuid] is not None:
-                tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
-            # append hift cache
-            if self.hift_cache_dict[uuid] is not None:
-                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
-                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
-            else:
-                hift_cache_source = torch.zeros(1, 1, 0)
-            # keep overlap mel and hift cache
-            if finalize is False:
-                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
-                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
-                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
-                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
-                tts_speech = tts_speech[:, :-self.source_cache_len]
-            else:
-                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+        tts_mel = self.flow.inference(token=token.to(self.device),
+                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                      prompt_token=prompt_token.to(self.device),
+                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                      prompt_feat=prompt_feat.to(self.device),
+                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                      embedding=embedding.to(self.device))
+        # mel overlap fade in out
+        if self.mel_overlap_dict[uuid] is not None:
+            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+            self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
         return tts_speech
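
# For reference, a minimal sketch of the assumed semantics of fade_in_out
# (the real helper lives in cosyvoice.utils.common): the first half of
# `window` ramps the new mel chunk in while the second half ramps the
# cached overlap out.
def fade_in_out(fade_in_mel, fade_out_mel, window):
    overlap = int(window.shape[0] / 2)
    fade_in_mel[:, :, :overlap] = (fade_in_mel[:, :, :overlap] * window[:overlap] +
                                   fade_out_mel[:, :, -overlap:] * window[overlap:])
    return fade_in_mel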
 
     def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
@@ -139,13 +137,12 @@ class CosyVoiceModel:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                     this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
-                    with self.flow_hift_context:
-                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                         prompt_token=flow_prompt_speech_token,
-                                                         prompt_feat=prompt_speech_feat,
-                                                         embedding=flow_embedding,
-                                                         uuid=this_uuid,
-                                                         finalize=False)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
                     yield {'tts_speech': this_tts_speech.cpu()}
                     with self.lock:
                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
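
The loop above decodes token_hop_len + token_overlap_len tokens per chunk but advances the buffer by only token_hop_len, so consecutive chunks share token_overlap_len tokens for the mel cross-fade. An illustrative schedule (values made up):

token_hop_len, token_overlap_len = 25, 5
buffer = list(range(60))  # stand-in for accumulated speech tokens
while len(buffer) >= token_hop_len + token_overlap_len:
    chunk = buffer[:token_hop_len + token_overlap_len]  # decoded together
    buffer = buffer[token_hop_len:]                     # advance by the hop only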
@@ -156,30 +153,26 @@ class CosyVoiceModel:
             p.join()
             # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
-            with self.flow_hift_context:
-                this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                 prompt_token=flow_prompt_speech_token,
-                                                 prompt_feat=prompt_speech_feat,
-                                                 embedding=flow_embedding,
-                                                 uuid=this_uuid,
-                                                 finalize=True)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
             p.join()
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
-            with self.flow_hift_context:
-                this_tts_speech = self.token2wav(token=this_tts_speech_token,
-                                                 prompt_token=flow_prompt_speech_token,
-                                                 prompt_feat=prompt_speech_feat,
-                                                 embedding=flow_embedding,
-                                                 uuid=this_uuid,
-                                                 finalize=True)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
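
The trailing torch.cuda.synchronize() can go because, with the side stream removed, every .cpu() call above is a blocking copy on the default stream. A small sketch of that assumption:

import torch

if torch.cuda.is_available():
    x = torch.randn(1, 16000, device='cuda')
    y = x * 2        # kernel queued on the default stream
    y_cpu = y.cpu()  # blocking copy waits for the kernel, so no explicit
                     # torch.cuda.synchronize() is needed afterwards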