1
0

estimator_trt.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import os
  2. import torch
  3. import tensorrt as trt
  4. import logging
  5. import threading
  6. _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
  7. _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
  8. _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
  9. class EstimatorTRT:
  10. def __init__(self, path_prefix: str, device: torch.device, fp16: bool = True):
  11. self.lock = threading.Lock()
  12. self.device = device
  13. with torch.cuda.device(device):
  14. self.input_names = ["x", "mask", "mu", "t", "spks", "cond"]
  15. self.output_name = "estimator_out"
  16. onnx_path = path_prefix + ".fp32.onnx"
  17. precision = ".fp16" if fp16 else ".fp32"
  18. trt_path = path_prefix + precision +".plan"
  19. self.fp16 = fp16
  20. self.logger = trt.Logger(trt.Logger.INFO)
  21. self.trt_runtime = trt.Runtime(self.logger)
  22. save_trt = not os.environ.get("NOT_SAVE_TRT", "0") == "1"
  23. if os.path.exists(trt_path):
  24. self.engine = self._load_trt(trt_path)
  25. else:
  26. self.engine = self._convert_onnx_to_trt(onnx_path, trt_path, save_trt)
  27. self.context = self.engine.create_execution_context()
  28. def _convert_onnx_to_trt(
  29. self, onnx_path: str, trt_path: str, save_trt: bool = True
  30. ):
  31. logging.info("Converting onnx to trt...")
  32. network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
  33. builder = trt.Builder(self.logger)
  34. network = builder.create_network(network_flags)
  35. parser = trt.OnnxParser(network, self.logger)
  36. config = builder.create_builder_config()
  37. config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
  38. if (self.fp16):
  39. config.set_flag(trt.BuilderFlag.FP16)
  40. profile = builder.create_optimization_profile()
  41. # load onnx model
  42. with open(onnx_path, "rb") as f:
  43. if not parser.parse(f.read()):
  44. for error in range(parser.num_errors):
  45. print(parser.get_error(error))
  46. exit(1)
  47. # set input shapes
  48. for i in range(len(self.input_names)):
  49. profile.set_shape(
  50. self.input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i]
  51. )
  52. tensor_dtype = trt.DataType.HALF if self.fp16 else trt.DataType.FLOAT
  53. # set input and output data type
  54. for i in range(network.num_inputs):
  55. input_tensor = network.get_input(i)
  56. input_tensor.dtype = tensor_dtype
  57. for i in range(network.num_outputs):
  58. output_tensor = network.get_output(i)
  59. output_tensor.dtype = tensor_dtype
  60. config.add_optimization_profile(profile)
  61. engine_bytes = builder.build_serialized_network(network, config)
  62. # save trt engine
  63. if save_trt:
  64. with open(trt_path, "wb") as f:
  65. f.write(engine_bytes)
  66. print("trt engine saved to {}".format(trt_path))
  67. engine = self.trt_runtime.deserialize_cuda_engine(engine_bytes)
  68. return engine
  69. def _load_trt(self, trt_path: str):
  70. logging.info("Found trt engine, loading...")
  71. with open(trt_path, "rb") as f:
  72. engine_bytes = f.read()
  73. engine = self.trt_runtime.deserialize_cuda_engine(engine_bytes)
  74. return engine
  75. def forward(
  76. self,
  77. x: torch.Tensor,
  78. mask: torch.Tensor,
  79. mu: torch.Tensor,
  80. t: torch.Tensor,
  81. spks: torch.Tensor,
  82. cond: torch.Tensor,
  83. ):
  84. with self.lock:
  85. with torch.cuda.device(self.device):
  86. self.context.set_input_shape("x", (2, 80, x.size(2)))
  87. self.context.set_input_shape("mask", (2, 1, x.size(2)))
  88. self.context.set_input_shape("mu", (2, 80, x.size(2)))
  89. self.context.set_input_shape("t", (2,))
  90. self.context.set_input_shape("spks", (2, 80))
  91. self.context.set_input_shape("cond", (2, 80, x.size(2)))
  92. # run trt engine
  93. self.context.execute_v2(
  94. [
  95. x.contiguous().data_ptr(),
  96. mask.contiguous().data_ptr(),
  97. mu.contiguous().data_ptr(),
  98. t.contiguous().data_ptr(),
  99. spks.contiguous().data_ptr(),
  100. cond.contiguous().data_ptr(),
  101. x.data_ptr(),
  102. ]
  103. )
  104. return x
  105. def __call__(
  106. self,
  107. x: torch.Tensor,
  108. mask: torch.Tensor,
  109. mu: torch.Tensor,
  110. t: torch.Tensor,
  111. spks: torch.Tensor,
  112. cond: torch.Tensor,
  113. ):
  114. return self.forward(x, mask, mu, t, spks, cond)