# model.py — CosyVoice3 token2wav (flow-only) Triton python backend model.
  1. import json
  2. import os
  3. import logging
  4. import queue
  5. import torch
  6. import numpy as np
  7. from torch.utils.dlpack import to_dlpack
  8. import triton_python_backend_utils as pb_utils
  9. from hyperpyyaml import load_hyperpyyaml
# Configure the root logger once at import time so every logger in this
# process inherits the timestamped format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger, standard `getLogger(__name__)` convention.
logger = logging.getLogger(__name__)
  12. class TrtContextWrapper:
  13. def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
  14. self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
  15. self.trt_engine = trt_engine
  16. self.device = device
  17. for _ in range(trt_concurrent):
  18. trt_context = trt_engine.create_execution_context()
  19. trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
  20. assert trt_context is not None
  21. self.trt_context_pool.put([trt_context, trt_stream])
  22. def acquire_estimator(self):
  23. return self.trt_context_pool.get(), self.trt_engine
  24. def release_estimator(self, context, stream):
  25. self.trt_context_pool.put([context, stream])
  26. def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16, autocast_mode=False):
  27. import tensorrt as trt
  28. logging.info("Converting onnx to trt...")
  29. if autocast_mode:
  30. network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
  31. else:
  32. network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
  33. trt_logger = trt.Logger(trt.Logger.INFO)
  34. builder = trt.Builder(trt_logger)
  35. network = builder.create_network(network_flags)
  36. parser = trt.OnnxParser(network, trt_logger)
  37. config = builder.create_builder_config()
  38. config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)
  39. if not autocast_mode and fp16:
  40. config.set_flag(trt.BuilderFlag.FP16)
  41. profile = builder.create_optimization_profile()
  42. with open(onnx_model, "rb") as f:
  43. if not parser.parse(f.read()):
  44. for error in range(parser.num_errors):
  45. print(parser.get_error(error))
  46. raise ValueError(f'failed to parse {onnx_model}')
  47. for i in range(len(trt_kwargs['input_names'])):
  48. profile.set_shape(trt_kwargs['input_names'][i],
  49. trt_kwargs['min_shape'][i],
  50. trt_kwargs['opt_shape'][i],
  51. trt_kwargs['max_shape'][i])
  52. if not autocast_mode:
  53. tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
  54. for i in range(network.num_inputs):
  55. network.get_input(i).dtype = tensor_dtype
  56. for i in range(network.num_outputs):
  57. network.get_output(i).dtype = tensor_dtype
  58. config.add_optimization_profile(profile)
  59. engine_bytes = builder.build_serialized_network(network, config)
  60. with open(trt_model, "wb") as f:
  61. f.write(engine_bytes)
  62. logging.info("Successfully converted onnx to trt")
  63. torch.set_num_threads(1)
class TritonPythonModel:
    """Triton Python model for CosyVoice3 token2wav (flow-only, stateless).
    Converts speech tokens to mel spectrogram using the CausalMaskedDiffWithDiT flow model.
    """

    def initialize(self, args):
        """Load the fp16 flow model and swap its estimator for a TRT engine.

        Called once by Triton at model load. ``args['model_config']`` is the
        JSON-encoded model config; its ``model_dir`` parameter must point at
        the CosyVoice3 checkpoint directory.
        """
        parameters = json.loads(args['model_config'])['parameters']
        # Triton wraps every config parameter value as {"string_value": ...}.
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        model_dir = model_params["model_dir"]
        self.device = torch.device("cuda")
        # Load flow model from cosyvoice3.yaml
        with open(os.path.join(model_dir, 'cosyvoice3.yaml'), 'r') as f:
            configs = load_hyperpyyaml(f, overrides={
                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
            })
        self.flow = configs['flow']
        self.fp16 = True
        # Convert to half precision before loading weights; load_state_dict
        # casts the checkpoint tensors into the existing fp16 parameters.
        self.flow.half()
        self.flow.load_state_dict(
            torch.load(os.path.join(model_dir, 'flow.pt'),
                map_location='cpu', weights_only=True),
            strict=True
        )
        self.flow.to(self.device).eval()
        # TRT acceleration for flow decoder estimator
        self.load_trt(model_dir)
        # Mel frames produced per speech token; used for token_offset slicing
        # in execute().
        self.token_mel_ratio = self.flow.token_mel_ratio
        logger.info(f"Token2wav (flow-only) initialized, token_mel_ratio={self.token_mel_ratio}")

    def load_trt(self, model_dir, trt_concurrent=1):
        """Replace the flow decoder's estimator with a TensorRT engine.

        Builds the .plan from the autocast-fp16 ONNX export if missing (the
        plan name embeds the CUDA device id since engines are device-specific),
        then wraps the deserialized engine in TrtContextWrapper.
        """
        device_id = torch.cuda.current_device()
        onnx_path = os.path.join(model_dir, 'flow.decoder.estimator.autocast_fp16.onnx')
        trt_path = os.path.join(model_dir, f'flow.decoder.estimator.autocast_fp16.{device_id}.plan')
        # Rebuild when the plan is missing or a previous build left a
        # zero-byte file.
        if not os.path.exists(trt_path) or os.path.getsize(trt_path) == 0:
            trt_kwargs = self.get_trt_kwargs()
            convert_onnx_to_trt(trt_path, trt_kwargs, onnx_path,
                fp16=True, autocast_mode=True)
        # Drop the PyTorch estimator before installing the TRT-backed one.
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(trt_path, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, f'failed to load trt {trt_path}'
        self.flow.decoder.estimator = TrtContextWrapper(
            estimator_engine, trt_concurrent=trt_concurrent, device=str(self.device))

    def get_trt_kwargs(self):
        """Dynamic-shape profile for the estimator's four inputs.

        Shapes are (batch, channels, frames); batch is fixed at 2 (presumably
        a CFG conditional/unconditional pair — confirm against the estimator
        export) and 80 matches the mel channel count used in execute().
        """
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape,
            'max_shape': max_shape, 'input_names': input_names}

    def execute(self, requests):
        """Handle a batch of requests: speech tokens -> mel spectrogram.

        Inputs per request: target_speech_tokens (required),
        prompt_speech_tokens / prompt_speech_feat / prompt_spk_embedding
        (declared optional but required in practice — see NOTE below),
        token_offset and finalize (truly optional). Emits one "mel" output
        of shape [80, T] per request.
        """
        responses = []
        for req_idx, request in enumerate(requests):
            target_speech_tokens = pb_utils.get_input_tensor_by_name(
                request, "target_speech_tokens")
            # dlpack handoff from Triton to torch, then move to the GPU.
            target_speech_tokens = torch.utils.dlpack.from_dlpack(
                target_speech_tokens.to_dlpack()).to(self.device)
            if target_speech_tokens.dim() == 1:
                target_speech_tokens = target_speech_tokens.unsqueeze(0)
            # Optional inputs
            # NOTE(review): despite the "Optional" label, the prompt inputs
            # are required — a missing prompt_speech_tokens raises ValueError.
            prompt_speech_tokens_pb = pb_utils.get_input_tensor_by_name(
                request, "prompt_speech_tokens")
            if prompt_speech_tokens_pb is not None:
                prompt_speech_tokens = torch.utils.dlpack.from_dlpack(
                    prompt_speech_tokens_pb.to_dlpack()).to(self.device)
                if prompt_speech_tokens.dim() == 1:
                    prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
                prompt_speech_feat = pb_utils.get_input_tensor_by_name(
                    request, "prompt_speech_feat")
                prompt_speech_feat = torch.utils.dlpack.from_dlpack(
                    prompt_speech_feat.to_dlpack()).to(self.device)
                if prompt_speech_feat.dim() == 2:
                    prompt_speech_feat = prompt_speech_feat.unsqueeze(0)  # [T, 80] -> [1, T, 80]
                prompt_spk_embedding = pb_utils.get_input_tensor_by_name(
                    request, "prompt_spk_embedding")
                prompt_spk_embedding = torch.utils.dlpack.from_dlpack(
                    prompt_spk_embedding.to_dlpack()).to(self.device)
                if prompt_spk_embedding.dim() == 1:
                    prompt_spk_embedding = prompt_spk_embedding.unsqueeze(0)
            else:
                raise ValueError("prompt_speech_tokens is required for CosyVoice3 token2wav")
            token_offset_pb = pb_utils.get_input_tensor_by_name(request, "token_offset")
            finalize_pb = pb_utils.get_input_tensor_by_name(request, "finalize")
            # finalize defaults to True (non-streaming) when not supplied.
            token_offset = token_offset_pb.as_numpy().item() if token_offset_pb is not None else None
            finalize = finalize_pb.as_numpy().item() if finalize_pb is not None else True
            streaming = not finalize
            # fp16 autocast for flow inference. NOTE(review):
            # torch.cuda.amp.autocast is deprecated in newer torch in favor
            # of torch.amp.autocast('cuda', ...) — works as written here.
            with torch.no_grad(), torch.cuda.amp.autocast(self.fp16):
                mel, _ = self.flow.inference(
                    token=target_speech_tokens,
                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_token=prompt_speech_tokens,
                    prompt_token_len=torch.tensor([prompt_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_feat=prompt_speech_feat,
                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
                    embedding=prompt_spk_embedding,
                    streaming=streaming,
                    finalize=finalize,
                )
            # Slice mel from token_offset if provided
            # (token_offset is in tokens; mel frames = tokens * token_mel_ratio).
            if token_offset is not None:
                mel = mel[:, :, token_offset * self.token_mel_ratio:]
            # Output mel as [80, T] (squeeze batch dim for Triton)
            mel_out = mel.squeeze(0).float()  # [80, T]
            mel_out = mel_out.cpu()  # otherwise, dlpack bug
            mel_tensor = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel_out))
            inference_response = pb_utils.InferenceResponse(output_tensors=[mel_tensor])
            responses.append(inference_response)
        return responses