# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os

import numpy as np
import onnxruntime
import torch
import torchaudio.compliance.kaldi as kaldi
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import to_dlpack

from cosyvoice.utils.common import TrtContextWrapper
from cosyvoice.utils.file_utils import convert_onnx_to_trt


class TritonPythonModel:
    """Triton Python model for speaker embedding extraction.

    This model takes reference audio input and extracts a speaker
    embedding with the CAM++ (campplus) model, running it either as a
    TensorRT engine or as an ONNX Runtime session. A client-side usage
    sketch is given at the end of this file.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.device = torch.device("cuda")
        model_dir = model_params["model_dir"]

        # NOTE: the GPU tag and the TensorRT switch are currently hardcoded;
        # the tag only selects the engine file name on disk.
        gpu = "l20"
        enable_trt = True
        if enable_trt:
            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
                              f'{model_dir}/campplus.onnx',
                              trt_concurrent=1,
                              fp16=False)
        else:
            # Fall back to an ONNX Runtime session on CPU
            campplus_model = f'{model_dir}/campplus.onnx'
            option = onnxruntime.SessionOptions()
            option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
            option.intra_op_num_threads = 1
            self.spk_model = onnxruntime.InferenceSession(
                campplus_model, sess_options=option,
                providers=["CPUExecutionProvider"])

    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
        """Load the speaker model as a TensorRT engine, converting it from
        ONNX first if the engine file is missing or empty."""
        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
            trt_kwargs = self.get_spk_trt_kwargs()
            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
        import tensorrt as trt
        with open(spk_model, 'rb') as f:
            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_spk_trt_kwargs(self):
        # Optimization profile for the (batch, frames, mel_bins) fbank input:
        # 4 to 3000 frames, optimized around 500 frames (roughly 5 s of audio
        # at the default 10 ms frame shift).
        min_shape = [(1, 4, 80)]
        opt_shape = [(1, 500, 80)]
        max_shape = [(1, 3000, 80)]
        input_names = ["input"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape,
                'max_shape': max_shape, 'input_names': input_names}

    def _extract_spk_embedding(self, speech):
        """Compute a 192-dim speaker embedding from a 16 kHz waveform."""
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        # Cepstral mean normalization over the time axis
        spk_feat = feat - feat.mean(dim=0, keepdim=True)
        if isinstance(self.spk_model, onnxruntime.InferenceSession):
            embedding = self.spk_model.run(
                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
            )[0].flatten().tolist()
            embedding = torch.tensor([embedding]).to(self.device)
        else:
            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
            # NOTE: need to synchronize when switching streams
            with torch.cuda.device(self.device):
                torch.cuda.current_stream().synchronize()
                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
                batch_size = spk_feat.size(0)
                with stream:
                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
                    embedding = torch.empty((batch_size, 192), device=spk_feat.device)
                    # Bind input/output buffers by device address, in the
                    # engine's tensor order
                    data_ptrs = [spk_feat.contiguous().data_ptr(),
                                 embedding.contiguous().data_ptr()]
                    for i, j in enumerate(data_ptrs):
                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
                    # Run the TRT engine
                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
                    torch.cuda.current_stream().synchronize()
            self.spk_model.release_estimator(spk_model, stream)
        return embedding.half()

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing the speaker embedding
        """
        responses = []
        # Process each request in the batch
        for request in requests:
            # Extract the reference waveform and move it to the GPU
            wav_array = pb_utils.get_input_tensor_by_name(
                request, "reference_wav").as_numpy()
            wav_array = torch.from_numpy(wav_array).to(self.device)
            embedding = self._extract_spk_embedding(wav_array)
            # Expose the embedding to Triton zero-copy via DLPack
            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
                "prompt_spk_embedding", to_dlpack(embedding))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[prompt_spk_embedding_tensor])
            responses.append(inference_response)
        return responses
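

# ---------------------------------------------------------------------------
# Usage sketch (never executed when Triton imports this file): a minimal gRPC
# client call against this model. The model name "speaker_embedding", the
# server URL, and the FP32 dummy waveform are illustrative assumptions, not
# taken from any deployment config; adjust them to match your model
# repository and config.pbtxt. Running this file standalone also assumes the
# top-level imports resolve in your environment.
if __name__ == "__main__":
    import tritonclient.grpc as grpcclient

    client = grpcclient.InferenceServerClient(url="localhost:8001")
    wav = np.random.randn(1, 16000).astype(np.float32)  # 1 s of dummy 16 kHz audio
    inp = grpcclient.InferInput("reference_wav", list(wav.shape), "FP32")
    inp.set_data_from_numpy(wav)
    result = client.infer("speaker_embedding", inputs=[inp])
    embedding = result.as_numpy("prompt_spk_embedding")
    print(embedding.shape)  # expected: (1, 192), float16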