export_trt.py 4.8 KB

# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import sys

logging.getLogger('matplotlib').setLevel(logging.WARNING)

try:
    import tensorrt
except ImportError:
    error_msg_zh = [
        "step.1 Download the TensorRT .tar.gz archive and extract it; download page: https://developer.nvidia.com/tensorrt/download/10x",
        "step.2 Install the TensorRT wheel that matches your Python version, e.g. pip install ${TensorRT-Path}/python/tensorrt-10.2.0-cp38-none-linux_x86_64.whl",
        "step.3 Add the TensorRT lib directory to your environment: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${TensorRT-Path}/lib/"
    ]
    print("\n".join(error_msg_zh))
    sys.exit(1)

import torch
from cosyvoice.cli.cosyvoice import CosyVoice


def get_args():
    parser = argparse.ArgumentParser(description='Export your model for deployment')
    parser.add_argument('--model_dir',
                        type=str,
                        default='pretrained_models/CosyVoice-300M-SFT',
                        help='Local path to the model directory')
    parser.add_argument('--export_half',
                        action='store_true',
                        help='Export with half precision (FP16)')
    args = parser.parse_args()
    print(args)
    return args
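
# Illustrative invocation (flags as defined above; --export_half is optional):
#   python export_trt.py --model_dir pretrained_models/CosyVoice-300M-SFT --export_half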

def main():
    args = get_args()

    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
    estimator = cosyvoice.model.flow.decoder.estimator

    dtype = torch.float32 if not args.export_half else torch.float16
    device = torch.device("cuda")
    batch_size = 1
    seq_len = 256
    hidden_size = cosyvoice.model.flow.output_size
    x = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
    mask = torch.ones((batch_size, 1, seq_len), dtype=dtype, device=device)
    mu = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)
    t = torch.rand((batch_size, ), dtype=dtype, device=device)
    spks = torch.rand((batch_size, hidden_size), dtype=dtype, device=device)
    cond = torch.rand((batch_size, hidden_size, seq_len), dtype=dtype, device=device)

    onnx_file_name = 'estimator_fp32.onnx' if not args.export_half else 'estimator_fp16.onnx'
    onnx_file_path = os.path.join(args.model_dir, onnx_file_name)
    dummy_input = (x, mask, mu, t, spks, cond)
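    # Shape notes (descriptive only): x, mu and cond are (batch, hidden_size, seq_len)
    # feature tensors, mask is (batch, 1, seq_len), t is a per-sample timestep and
    # spks is a (batch, hidden_size) speaker embedding. The random values are only
    # used to trace the graph; seq_len is made dynamic in the export below.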

    estimator = estimator.to(dtype)

    torch.onnx.export(
        estimator,
        dummy_input,
        onnx_file_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
        output_names=['estimator_out'],
        dynamic_axes={
            'x': {2: 'seq_len'},
            'mask': {2: 'seq_len'},
            'mu': {2: 'seq_len'},
            'cond': {2: 'seq_len'},
            'estimator_out': {2: 'seq_len'},
        }
    )
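    # dynamic_axes marks dimension 2 (the time axis) of the listed tensors as
    # variable-length, so the exported graph accepts any seq_len at inference time.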

    tensorrt_path = os.environ.get('tensorrt_root_dir')
    if not tensorrt_path:
        raise EnvironmentError("Please set the 'tensorrt_root_dir' environment variable.")

    if not os.path.isdir(tensorrt_path):
        raise FileNotFoundError(f"The directory {tensorrt_path} does not exist.")

    trt_lib_path = os.path.join(tensorrt_path, "lib")
    if trt_lib_path not in os.environ.get('LD_LIBRARY_PATH', ''):
        print(f"Adding TensorRT lib path {trt_lib_path} to LD_LIBRARY_PATH.")
        os.environ['LD_LIBRARY_PATH'] = f"{os.environ.get('LD_LIBRARY_PATH', '')}:{trt_lib_path}"
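    # Note: updating LD_LIBRARY_PATH here cannot affect libraries already loaded into
    # this Python process; the new value is only inherited by the trtexec child
    # process launched via os.system below.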

    trt_file_name = 'estimator_fp32.plan' if not args.export_half else 'estimator_fp16.plan'
    trt_file_path = os.path.join(args.model_dir, trt_file_name)

    trtexec_bin = os.path.join(tensorrt_path, 'bin/trtexec')
    trtexec_cmd = f"{trtexec_bin} --onnx={onnx_file_path} --saveEngine={trt_file_path} " \
                  "--minShapes=x:1x80x1,mask:1x1x1,mu:1x80x1,t:1,spks:1x80,cond:1x80x1 " \
                  "--maxShapes=x:1x80x4096,mask:1x1x4096,mu:1x80x4096,t:1,spks:1x80,cond:1x80x4096 --verbose " + \
                  ("--fp16" if args.export_half else "")
  95. print("execute ", trtexec_cmd)
  96. os.system(trtexec_cmd)
  97. # print("x.shape", x.shape)
  98. # print("mask.shape", mask.shape)
  99. # print("mu.shape", mu.shape)
  100. # print("t.shape", t.shape)
  101. # print("spks.shape", spks.shape)
  102. # print("cond.shape", cond.shape)


if __name__ == "__main__":
    main()