|
|
@@ -1,8 +1,103 @@
|
|
|
-# TODO 跟export_jit一样的逻辑,完成flow部分的estimator的onnx导出。
|
|
|
-# tensorrt的安装方式,再这里写一下步骤提示如下,如果没有安装,那么不要执行这个脚本,提示用户先安装,不给选择
|
|
|
+import argparse
|
|
|
+import logging
|
|
|
+import os
|
|
|
+import sys
|
|
|
+
|
|
|
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
|
+
|
|
|
try:
|
|
|
import tensorrt
|
|
|
except ImportError:
|
|
|
- print('step1, 下载\n step2. 解压,安装whl,')
|
|
|
-# 安装命令里tensosrt的根目录用环境变量导入,比如os.environ['tensorrt_root_dir']/bin/exetrace,然后python里subprocess里执行导出命令
|
|
|
-# 后面我会在run.sh里写好执行命令 tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
|
|
|
+ error_msg_zh = [
|
|
|
+ "step.1 下载 tensorrt .tar.gz 压缩包并解压,下载地址: https://developer.nvidia.com/tensorrt/download/10x",
|
|
|
+ "step.2 使用 tensorrt whl 包进行安装根据 python 版本对应进行安装,如 pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl",
|
|
|
+ "step.3 将 tensorrt 的 lib 路径添加进环境变量中,export LD_LIBRARY_PATH=${TensorRT-Path}/lib/"
|
|
|
+ ]
|
|
|
+ print("\n".join(error_msg_zh))
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+import torch
|
|
|
+from cosyvoice.cli.cosyvoice import CosyVoice
|
|
|
+
|
|
|
+def get_args():
|
|
|
+ parser = argparse.ArgumentParser(description='Export your model for deployment')
|
|
|
+ parser.add_argument('--model_dir',
|
|
|
+ type=str,
|
|
|
+ default='pretrained_models/CosyVoice-300M',
|
|
|
+ help='Local path to the model directory')
|
|
|
+
|
|
|
+ parser.add_argument('--export_half',
|
|
|
+ action='store_true',
|
|
|
+ help='Export with half precision (FP16)')
|
|
|
+
|
|
|
+ args = parser.parse_args()
|
|
|
+ print(args)
|
|
|
+ return args
|
|
|
+
|
|
|
+def main():
|
|
|
+ args = get_args()
|
|
|
+
|
|
|
+ cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
|
|
|
+
|
|
|
+ flow = cosyvoice.model.flow
|
|
|
+ estimator = cosyvoice.model.flow.decoder.estimator
|
|
|
+
|
|
|
+ dtype = torch.float32 if not args.export_half else torch.float16
|
|
|
+ device = torch.device("cuda")
|
|
|
+ batch_size = 1
|
|
|
+ seq_len = 1024
|
|
|
+ hidden_size = flow.output_size
|
|
|
+ x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
|
|
|
+ mask = torch.zeros((batch_size, 1, seq_len), dtype=dtype, device=device)
|
|
|
+ mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
|
|
|
+ t = torch.tensor([0.], dtype=dtype, device=device)
|
|
|
+ spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device)
|
|
|
+ cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
|
|
|
+
|
|
|
+ onnx_file_name = 'estimator_fp16.onnx' if args.export_half else 'estimator_fp32.onnx'
|
|
|
+ onnx_file_path = os.path.join(args.model_dir, onnx_file_name)
|
|
|
+ dummy_input = (x, mask, mu, t, spks, cond)
|
|
|
+
|
|
|
+ estimator = estimator.to(dtype)
|
|
|
+
|
|
|
+ torch.onnx.export(
|
|
|
+ estimator,
|
|
|
+ dummy_input,
|
|
|
+ onnx_file_path,
|
|
|
+ export_params=True,
|
|
|
+ opset_version=18,
|
|
|
+ do_constant_folding=True,
|
|
|
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
|
|
+ output_names=['output'],
|
|
|
+ dynamic_axes={
|
|
|
+ 'x': {2: 'seq_len'},
|
|
|
+ 'mask': {2: 'seq_len'},
|
|
|
+ 'mu': {2: 'seq_len'},
|
|
|
+ 'cond': {2: 'seq_len'},
|
|
|
+ 'output': {2: 'seq_len'},
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ tensorrt_path = os.environ.get('tensorrt_root_dir')
|
|
|
+ if not tensorrt_path:
|
|
|
+ raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.")
|
|
|
+
|
|
|
+ if not os.path.isdir(tensorrt_path):
|
|
|
+ raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.")
|
|
|
+
|
|
|
+ trt_lib_path = os.path.join(tensorrt_path, "lib")
|
|
|
+ if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''):
|
|
|
+ print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.")
|
|
|
+ os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}"
|
|
|
+
|
|
|
+ trt_file_name = 'estimator_fp16.plan' if args.export_half else 'estimator_fp32.plan'
|
|
|
+ trt_file_path = os.path.join(args.model_dir, trt_file_name)
|
|
|
+
|
|
|
+ trtexec_cmd = f"{tensorrt_path}/bin/trtexec --onnx={onnx_file_path} --saveEngine={trt_file_path} " \
|
|
|
+ "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \
|
|
|
+ "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose"
|
|
|
+
|
|
|
+ os.system(trtexec_cmd)
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|