import json
import logging
import os
import queue

import torch
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from torch.utils.dlpack import from_dlpack, to_dlpack

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class TrtContextWrapper:
    """Pool of TensorRT execution contexts, each paired with its own CUDA stream,
    so up to `trt_concurrent` inferences can run without contending for a context."""

    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        self.trt_engine = trt_engine
        self.device = device
        for _ in range(trt_concurrent):
            trt_context = trt_engine.create_execution_context()
            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
            assert trt_context is not None, 'failed to create TRT execution context'
            self.trt_context_pool.put([trt_context, trt_stream])

    def acquire_estimator(self):
        # Blocks until a [context, stream] pair is available.
        return self.trt_context_pool.get(), self.trt_engine

    def release_estimator(self, context, stream):
        self.trt_context_pool.put([context, stream])
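
# Usage sketch for the pool above (illustrative; the real binding/launch code
# lives in the flow decoder that consumes this wrapper):
#
#   [context, stream], engine = wrapper.acquire_estimator()
#   try:
#       ...  # set tensor shapes/addresses on `context`, launch on `stream`
#   finally:
#       wrapper.release_estimator(context, stream)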


def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16, autocast_mode=False):
    import tensorrt as trt
    logger.info("Converting onnx to trt...")
    if autocast_mode:
        # A strongly typed network keeps the dtypes recorded in the ONNX graph,
        # so no FP16 builder flag or I/O dtype overrides are applied below.
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    else:
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    trt_logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, trt_logger)
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4 GiB workspace
    if not autocast_mode and fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    profile = builder.create_optimization_profile()
    with open(onnx_model, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                logger.error(parser.get_error(error))
            raise ValueError(f'failed to parse {onnx_model}')
    for i in range(len(trt_kwargs['input_names'])):
        profile.set_shape(trt_kwargs['input_names'][i],
                          trt_kwargs['min_shape'][i],
                          trt_kwargs['opt_shape'][i],
                          trt_kwargs['max_shape'][i])
    if not autocast_mode:
        tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
        for i in range(network.num_inputs):
            network.get_input(i).dtype = tensor_dtype
        for i in range(network.num_outputs):
            network.get_output(i).dtype = tensor_dtype
    config.add_optimization_profile(profile)
    engine_bytes = builder.build_serialized_network(network, config)
    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logger.info("Successfully converted onnx to trt")


torch.set_num_threads(1)


class TritonPythonModel:
    """Triton Python model for CosyVoice3 token2wav (flow-only, stateless).

    Converts speech tokens to a mel spectrogram using the CausalMaskedDiffWithDiT
    flow model.
    """

    def initialize(self, args):
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        model_dir = model_params["model_dir"]
        self.device = torch.device("cuda")

        # Load the flow model definition from cosyvoice3.yaml
        with open(os.path.join(model_dir, 'cosyvoice3.yaml'), 'r') as f:
            configs = load_hyperpyyaml(f, overrides={
                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
            })
        self.flow = configs['flow']
        self.fp16 = True
        self.flow.half()
        self.flow.load_state_dict(
            torch.load(os.path.join(model_dir, 'flow.pt'),
                       map_location='cpu', weights_only=True),
            strict=True
        )
        self.flow.to(self.device).eval()

        # TRT acceleration for the flow decoder estimator
        self.load_trt(model_dir)
        self.token_mel_ratio = self.flow.token_mel_ratio
        logger.info(f"Token2wav (flow-only) initialized, token_mel_ratio={self.token_mel_ratio}")

    def load_trt(self, model_dir, trt_concurrent=1):
        device_id = torch.cuda.current_device()
        onnx_path = os.path.join(model_dir, 'flow.decoder.estimator.autocast_fp16.onnx')
        trt_path = os.path.join(model_dir, f'flow.decoder.estimator.autocast_fp16.{device_id}.plan')
        # Build the engine once per device and cache the serialized plan on disk.
        if not os.path.exists(trt_path) or os.path.getsize(trt_path) == 0:
            trt_kwargs = self.get_trt_kwargs()
            convert_onnx_to_trt(trt_path, trt_kwargs, onnx_path,
                                fp16=True, autocast_mode=True)
        # Swap the PyTorch estimator for the TRT engine wrapped in a context pool.
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(trt_path, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, f'failed to load trt {trt_path}'
        self.flow.decoder.estimator = TrtContextWrapper(
            estimator_engine, trt_concurrent=trt_concurrent, device=str(self.device))
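
    # The optimization profile below follows the CosyVoice convention: a fixed
    # batch of 2 (presumably the conditional and unconditional classifier-free
    # guidance passes batched together), 80 mel channels (1 for the mask), and a
    # dynamic frame axis from 4 to 3000 frames. Raise max_shape for longer audio.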
    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape,
                'max_shape': max_shape, 'input_names': input_names}

    def execute(self, requests):
        responses = []
        for request in requests:
            target_speech_tokens = pb_utils.get_input_tensor_by_name(
                request, "target_speech_tokens")
            target_speech_tokens = from_dlpack(
                target_speech_tokens.to_dlpack()).to(self.device)
            if target_speech_tokens.dim() == 1:
                target_speech_tokens = target_speech_tokens.unsqueeze(0)

            # Prompt inputs (required; a missing prompt_speech_tokens raises below)
            prompt_speech_tokens_pb = pb_utils.get_input_tensor_by_name(
                request, "prompt_speech_tokens")
            if prompt_speech_tokens_pb is not None:
                prompt_speech_tokens = from_dlpack(
                    prompt_speech_tokens_pb.to_dlpack()).to(self.device)
                if prompt_speech_tokens.dim() == 1:
                    prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
                prompt_speech_feat = pb_utils.get_input_tensor_by_name(
                    request, "prompt_speech_feat")
                prompt_speech_feat = from_dlpack(
                    prompt_speech_feat.to_dlpack()).to(self.device)
                if prompt_speech_feat.dim() == 2:
                    prompt_speech_feat = prompt_speech_feat.unsqueeze(0)  # [T, 80] -> [1, T, 80]
                prompt_spk_embedding = pb_utils.get_input_tensor_by_name(
                    request, "prompt_spk_embedding")
                prompt_spk_embedding = from_dlpack(
                    prompt_spk_embedding.to_dlpack()).to(self.device)
                if prompt_spk_embedding.dim() == 1:
                    prompt_spk_embedding = prompt_spk_embedding.unsqueeze(0)
            else:
                raise ValueError("prompt_speech_tokens is required for CosyVoice3 token2wav")

            token_offset_pb = pb_utils.get_input_tensor_by_name(request, "token_offset")
            finalize_pb = pb_utils.get_input_tensor_by_name(request, "finalize")
            token_offset = token_offset_pb.as_numpy().item() if token_offset_pb is not None else None
            finalize = finalize_pb.as_numpy().item() if finalize_pb is not None else True
            streaming = not finalize

            with torch.no_grad(), torch.cuda.amp.autocast(self.fp16):
                mel, _ = self.flow.inference(
                    token=target_speech_tokens,
                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_token=prompt_speech_tokens,
                    prompt_token_len=torch.tensor([prompt_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_feat=prompt_speech_feat,
                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
                    embedding=prompt_spk_embedding,
                    streaming=streaming,
                    finalize=finalize,
                )

            # Slice mel from token_offset if provided (frames before it were
            # already emitted in earlier streaming chunks).
            if token_offset is not None:
                mel = mel[:, :, token_offset * self.token_mel_ratio:]

            # Output mel as [80, T] (squeeze batch dim for Triton)
            mel_out = mel.squeeze(0).float()
            mel_out = mel_out.cpu()  # work around a dlpack issue with GPU tensors
            mel_tensor = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel_out))
            responses.append(pb_utils.InferenceResponse(output_tensors=[mel_tensor]))
        return responses
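

if __name__ == '__main__':
    # Standalone client sketch, never executed by Triton itself. It assumes a
    # server on localhost:8000 serving this model under the name "token2wav",
    # with INT32 token inputs, FP32 feature inputs, and a 192-dim speaker
    # embedding -- these details are assumptions, not taken from this file.
    import numpy as np
    import tritonclient.http as httpclient

    client = httpclient.InferenceServerClient(url='localhost:8000')
    dummy = {
        'target_speech_tokens': np.zeros((1, 50), dtype=np.int32),
        'prompt_speech_tokens': np.zeros((1, 20), dtype=np.int32),
        'prompt_speech_feat': np.zeros((1, 100, 80), dtype=np.float32),
        'prompt_spk_embedding': np.zeros((1, 192), dtype=np.float32),
    }
    inputs = []
    for name, arr in dummy.items():
        infer_input = httpclient.InferInput(
            name, list(arr.shape), 'INT32' if arr.dtype == np.int32 else 'FP32')
        infer_input.set_data_from_numpy(arr)
        inputs.append(infer_input)
    result = client.infer('token2wav', inputs)
    print('mel shape:', result.as_numpy('mel').shape)  # expected [80, T]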