1
0
zhoubofan.zbf 1 vuosi sitten
vanhempi
commit
53a3c1b17f

+ 17 - 9
cosyvoice/bin/export_trt.py

@@ -38,23 +38,21 @@ def main():
     args = get_args()
 
     cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
-
-    flow = cosyvoice.model.flow
     estimator = cosyvoice.model.flow.decoder.estimator
 
     dtype = torch.float32 if not args.export_half else torch.float16
     device = torch.device("cuda")
     batch_size = 1
-    seq_len = 1024
-    hidden_size = flow.output_size
+    seq_len = 256
+    hidden_size = cosyvoice.model.flow.output_size
     x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
-    mask = torch.zeros((batch_size, 1, seq_len), dtype=dtype, device=device)
+    mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device)
     mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
-    t = torch.tensor([0.], dtype=dtype, device=device)
+    t = torch.rand((batch_size, ), dtype=dtype, device=device)
     spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device)
     cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
 
-    onnx_file_name = 'estimator_fp16.onnx' if args.export_half else 'estimator_fp32.onnx'
+    onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx'
     onnx_file_path = os.path.join(args.model_dir, onnx_file_name)
     dummy_input = (x, mask, mu, t, spks, cond)
 
@@ -90,14 +88,24 @@ def main():
         print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.")
         os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}"
 
-    trt_file_name = 'estimator_fp16.plan' if args.export_half else 'estimator_fp32.plan'
+    trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan'
     trt_file_path = os.path.join(args.model_dir, trt_file_name)
 
     trtexec_cmd = f"{tensorrt_path}/bin/trtexec --onnx={onnx_file_path} --saveEngine={trt_file_path} " \
                   "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \
-                  "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose"
+                  "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \
+                  ("--fp16" if args.export_half else "")
+# /ossfs/workspace/TensorRT-10.2.0.19/bin/trtexec --onnx=estimator_fp32.onnx --saveEngine=estimator_fp32.plan --minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 --maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose
+    print("execute ", trtexec_cmd)
 
     os.system(trtexec_cmd)
 
+    print("x.shape", x.shape)
+    print("mask.shape", mask.shape)
+    print("mu.shape", mu.shape)
+    print("t.shape", t.shape)
+    print("spks.shape", spks.shape)
+    print("cond.shape", cond.shape)
+
 if __name__ == "__main__":
     main()

+ 4 - 1
cosyvoice/cli/cosyvoice.py

@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True, load_trt=True, use_fp16=False):
+    def __init__(self, model_dir, load_jit=True, load_trt=False, use_fp16=False):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -39,11 +39,14 @@ class CosyVoice:
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        load_jit = False
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                     '{}/llm.llm.fp16.zip'.format(model_dir))
+
         if load_trt:
             self.model.load_trt(model_dir, use_fp16)
+            
         del configs
 
     def list_avaliable_spks(self):

+ 1 - 1
cosyvoice/flow/flow.py

@@ -107,7 +107,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # concat text and prompt_text
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
-        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
+        mask = (~make_pad_mask(token_len)).to(embedding.dtype).unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode

+ 36 - 0
cosyvoice/flow/flow_matching.py

@@ -32,6 +32,7 @@ class ConditionalCFM(BASECFM):
 
         self.estimator_context = None
         self.estimator_engine = None
+        self.is_saved = None
 
     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
@@ -123,6 +124,41 @@ class ConditionalCFM(BASECFM):
             self.estimator_context.execute_async_v3(stream_handle=handle)
             return ret
         else:
+
+            if self.is_saved == None:
+                self.is_saved = True
+                output = self.estimator.forward(x, mask, mu, t, spks, cond)
+                torch.save(x, "x.pt")
+                torch.save(mask, "mask.pt")
+                torch.save(mu, "mu.pt")
+                torch.save(t, "t.pt")
+                torch.save(spks, "spks.pt")
+                torch.save(cond, "cond.pt")
+                torch.save(output, "output.pt")
+                dummy_input = (x, mask, mu, t, spks, cond)
+                torch.onnx.export(
+                    self.estimator,
+                    dummy_input,
+                    "estimator_fp32.onnx",
+                    export_params=True,
+                    opset_version=17,
+                    do_constant_folding=True,
+                    input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+                    output_names=['output'],
+                    dynamic_axes={
+                        'x': {2: 'seq_len'},
+                        'mask': {2: 'seq_len'},
+                        'mu': {2: 'seq_len'},
+                        'cond': {2: 'seq_len'},
+                        'output': {2: 'seq_len'},
+                    }
+                )
+            # print("x, x.shape", x, x.shape)
+            # print("mask, mask.shape", mask, mask.shape)
+            # print("mu, mu.shape", mu, mu.shape)
+            # print("t, t.shape", t, t.shape)
+            # print("spks, spks.shape", spks, spks.shape)
+            # print("cond, cond.shape", cond, cond.shape)
             return self.estimator.forward(x, mask, mu, t, spks, cond)
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):