1
0
禾息 1 vuosi sitten
vanhempi
commit
6e7f5b922a
3 muutettua tiedostoa jossa 12 lisäystä ja 49 poistoa
  1. 3 3
      cosyvoice/bin/export_trt.py
  2. 1 2
      cosyvoice/cli/model.py
  3. 8 44
      cosyvoice/flow/flow_matching.py

+ 3 - 3
cosyvoice/bin/export_trt.py

@@ -66,13 +66,13 @@ def main():
         opset_version=18,
         do_constant_folding=True,
         input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
-        output_names=['output'],
+        output_names=['estimator_out'],
         dynamic_axes={
             'x': {2: 'seq_len'},
             'mask': {2: 'seq_len'},
             'mu': {2: 'seq_len'},
             'cond': {2: 'seq_len'},
-            'output': {2: 'seq_len'},
+            'estimator_out': {2: 'seq_len'},
         }
     )
 
@@ -95,7 +95,7 @@ def main():
                   "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \
                   "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \
                   ("--fp16" if args.export_half else "")
-# /ossfs/workspace/TensorRT-10.2.0.19/bin/trtexec --onnx=estimator_fp32.onnx --saveEngine=estimator_fp32.plan --minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 --maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose
+    
     print("execute ", trtexec_cmd)
 
     os.system(trtexec_cmd)

+ 1 - 2
cosyvoice/cli/model.py

@@ -83,8 +83,7 @@ class CosyVoiceModel:
         with open(trt_file_path, 'rb') as f:
             serialized_engine = f.read()
         engine = runtime.deserialize_cuda_engine(serialized_engine)
-        self.flow.decoder.estimator_context = engine.create_execution_context()
-        self.flow.decoder.estimator_engine = engine
+        self.flow.decoder.estimator = engine.create_execution_context()
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:

+ 8 - 44
cosyvoice/flow/flow_matching.py

@@ -30,10 +30,6 @@ class ConditionalCFM(BASECFM):
         # Just change the architecture of the estimator here
         self.estimator = estimator
 
-        self.estimator_context = None
-        self.estimator_engine = None
-        self.is_saved = None
-
     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
         """Forward diffusion
@@ -102,7 +98,11 @@ class ConditionalCFM(BASECFM):
         return sol[-1]
 
     def forward_estimator(self, x, mask, mu, t, spks, cond):
-        if self.estimator_context is not None:
+
+        if not isinstance(self.estimator, torch.nn.Module):
+            return self.estimator.forward(x, mask, mu, t, spks, cond)
+
+        else:
             assert self.training is False, 'tensorrt cannot be used in training'
             bs = x.shape[0]
             hs = x.shape[1]
@@ -116,50 +116,14 @@ class ConditionalCFM(BASECFM):
             self.estimator_context.set_input_shape("spks", spks.shape)
             self.estimator_context.set_input_shape("cond", cond.shape)
             bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
+            names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out']
             
             for i in range(len(bindings)):
-                self.estimator_context.set_tensor_address(self.estimator_engine.get_tensor_name(i), bindings[i])
+                self.estimator.set_tensor_address(names[i], bindings[i])
 
             handle = torch.cuda.current_stream().cuda_stream
-            self.estimator_context.execute_async_v3(stream_handle=handle)
+            self.estimator.execute_async_v3(stream_handle=handle)
             return ret
-        else:
-
-            if self.is_saved == None:
-                self.is_saved = True
-                output = self.estimator.forward(x, mask, mu, t, spks, cond)
-                torch.save(x, "x.pt")
-                torch.save(mask, "mask.pt")
-                torch.save(mu, "mu.pt")
-                torch.save(t, "t.pt")
-                torch.save(spks, "spks.pt")
-                torch.save(cond, "cond.pt")
-                torch.save(output, "output.pt")
-                dummy_input = (x, mask, mu, t, spks, cond)
-                torch.onnx.export(
-                    self.estimator,
-                    dummy_input,
-                    "estimator_fp32.onnx",
-                    export_params=True,
-                    opset_version=17,
-                    do_constant_folding=True,
-                    input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
-                    output_names=['output'],
-                    dynamic_axes={
-                        'x': {2: 'seq_len'},
-                        'mask': {2: 'seq_len'},
-                        'mu': {2: 'seq_len'},
-                        'cond': {2: 'seq_len'},
-                        'output': {2: 'seq_len'},
-                    }
-                )
-            # print("x, x.shape", x, x.shape)
-            # print("mask, mask.shape", mask, mask.shape)
-            # print("mu, mu.shape", mu, mu.shape)
-            # print("t, t.shape", t, t.shape)
-            # print("spks, spks.shape", spks, spks.shape)
-            # print("cond, cond.shape", cond, cond.shape)
-            return self.estimator.forward(x, mask, mu, t, spks, cond)
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss