1
0
Selaa lähdekoodia

add trt script TODO

lyuxiang.lx 1 vuosi sitten
vanhempi
commit
f1e374a9bb

+ 1 - 1
cosyvoice/bin/export.py → cosyvoice/bin/export_jit.py

@@ -44,7 +44,7 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
 
-    cosyvoice = CosyVoice(args.model_dir, load_script=False)
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
 
     # 1. export llm text_encoder
     llm_text_encoder = cosyvoice.model.llm.text_encoder.half()

+ 8 - 0
cosyvoice/bin/export_trt.py

@@ -0,0 +1,8 @@
+# TODO 跟export_jit一样的逻辑,完成flow部分的estimator的onnx导出。
+# tensorrt的安装方式,再这里写一下步骤提示如下,如果没有安装,那么不要执行这个脚本,提示用户先安装,不给选择
+try:
+    import tensorrt
+except ImportError:
+    print('step1, 下载\n step2. 解压,安装whl,')
+# 安装命令里tensosrt的根目录用环境变量导入,比如os.environ['tensorrt_root_dir']/bin/exetrace,然后python里subprocess里执行导出命令
+# 后面我会在run.sh里写好执行命令 tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx

+ 6 - 3
cosyvoice/cli/cosyvoice.py

@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_script=True):
+    def __init__(self, model_dir, load_jit=True, load_trt=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -39,9 +39,12 @@ class CosyVoice:
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
-        if load_script:
-            self.model.load_script('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+        if load_jit:
+            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                     '{}/llm.llm.fp16.zip'.format(model_dir))
+        if load_trt:
+            # TODO
+            self.model.load_trt()
         del configs
 
     def list_avaliable_spks(self):

+ 6 - 1
cosyvoice/cli/model.py

@@ -53,12 +53,17 @@ class CosyVoiceModel:
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()
 
-    def load_script(self, llm_text_encoder_model, llm_llm_model):
+    def load_jit(self, llm_text_encoder_model, llm_llm_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model)
         self.llm.text_encoder = llm_text_encoder
         llm_llm = torch.jit.load(llm_llm_model)
         self.llm.llm = llm_llm
 
+    def load_trt(self):
+        # TODO 你需要的TRT推理的准备
+        self.flow.decoder.estimator = xxx
+        self.flow.decoder.session = xxx
+
     def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),

+ 10 - 2
cosyvoice/flow/flow_matching.py

@@ -77,10 +77,10 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         for step in range(1, len(t_span)):
-            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
+            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.estimator(
+                cfg_dphi_dt = self.forward_estimator(
                     x, mask,
                     torch.zeros_like(mu), t,
                     torch.zeros_like(spks) if spks is not None else None,
@@ -96,6 +96,14 @@ class ConditionalCFM(BASECFM):
 
         return sol[-1]
 
+    # TODO
+    def forward_estimator(self):
+        if isinstance(self.estimator, trt):
+            assert self.training is False, 'tensorrt cannot be used in training'
+            return xxx
+        else:
+            return self.estimator.forward
+
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss