1
0
Selaa lähdekoodia

revert trt TODO

lyuxiang.lx 1 vuosi sitten
vanhempi
commit
1ab3186799
3 muutettua tiedostoa joissa 5 lisäystä ja 22 poistoa
  1. 1 4
      cosyvoice/cli/cosyvoice.py
  2. 2 8
      cosyvoice/cli/model.py
  3. 2 10
      cosyvoice/flow/flow_matching.py

+ 1 - 4
cosyvoice/cli/cosyvoice.py

@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True, load_trt=True):
+    def __init__(self, model_dir, load_jit=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -42,9 +42,6 @@ class CosyVoice:
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                     '{}/llm.llm.fp16.zip'.format(model_dir))
-        if load_trt:
-            # TODO
-            self.model.load_trt()
         del configs
 
     def list_avaliable_spks(self):

+ 2 - 8
cosyvoice/cli/model.py

@@ -66,11 +66,6 @@ class CosyVoiceModel:
         llm_llm = torch.jit.load(llm_llm_model)
         self.llm.llm = llm_llm
 
-    def load_trt(self):
-        # TODO 你需要的TRT推理的准备
-        self.flow.decoder.estimator = xxx
-        self.flow.decoder.session = xxx
-
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
@@ -126,7 +121,6 @@ class CosyVoiceModel:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
-        p.join()
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
@@ -147,7 +141,7 @@ class CosyVoiceModel:
                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            # p.join()
+            p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
@@ -160,7 +154,7 @@ class CosyVoiceModel:
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            # p.join()
+            p.join()
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,

+ 2 - 10
cosyvoice/flow/flow_matching.py

@@ -77,10 +77,10 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         for step in range(1, len(t_span)):
-            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
+            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.forward_estimator(
+                cfg_dphi_dt = self.estimator(
                     x, mask,
                     torch.zeros_like(mu), t,
                     torch.zeros_like(spks) if spks is not None else None,
@@ -96,14 +96,6 @@ class ConditionalCFM(BASECFM):
 
         return sol[-1]
 
-    # TODO
-    def forward_estimator(self):
-        if isinstance(self.estimator, trt):
-            assert self.training is False, 'tensorrt cannot be used in training'
-            return xxx
-        else:
-            return self.estimator.forward
-
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss