# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from typing import List, Dict

import torch
from torch.utils.dlpack import to_dlpack

import triton_python_backend_utils as pb_utils

from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Speech token IDs produced upstream are offset by the size of the original text
# vocabulary; subtracting this constant shifts them back into the flow model's
# speech-token range (see execute() below).
ORIGINAL_VOCAB_SIZE = 151663


class CosyVoice2:

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
        self.model_dir = model_dir
        self.fp16 = fp16
        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
        self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                trt_concurrent,
                                self.fp16)


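# Token-to-wav core of CosyVoice2: only the flow (mel generator) and hift (vocoder)
# modules are loaded here; the LLM that produces the speech tokens is expected to run
# in a separate model of the Triton pipeline.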
class CosyVoice2Model:

    def __init__(self,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        if self.fp16 is True:
            self.flow.half()

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load(self, flow_model, hift_model):
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        # Build the TensorRT engine from the ONNX export if the plan file is missing or empty.
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

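    # Dynamic-shape profiles for the TensorRT builder: "x", "mu" and "cond" carry 80 mel
    # channels and "mask" a single channel, with a time axis allowed to range from 4 to
    # 3000 frames (500 being the optimization point). The leading dimension of 2 is assumed
    # to correspond to the paired conditional/unconditional pass of the flow decoder.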
    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}


class TritonPythonModel:
    """Triton Python model for the CosyVoice2 token-to-wav stage.

    This model takes prompt and target speech tokens (plus prompt mel features and a
    speaker embedding) as input and generates audio waveforms using the CosyVoice2
    flow decoder and HiFT vocoder.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        # Initialize device and vocoder
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
        self.token2wav_model = CosyVoice2(
            model_dir, load_jit=True, load_trt=True, fp16=True
        )
        logger.info("Token2Wav initialized successfully")

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing generated waveforms
        """
        responses = []
        # Process each request in batch
        for request in requests:
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
            prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
            prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()

            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
            prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
            prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
            prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)

            # Shift the speech tokens back into the flow model's range by the original vocab size
            prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE

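            # Flow matching: generate the target mel-spectrogram conditioned on the prompt
            # tokens, prompt mel features and the speaker embedding.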
            tts_mel, _ = self.token2wav_model.model.flow.inference(
                token=target_speech_tokens,
                token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                prompt_token=prompt_speech_tokens,
                prompt_token_len=torch.tensor([prompt_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                prompt_feat=prompt_speech_feat,
                prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
                embedding=prompt_spk_embedding,
                streaming=False,
                finalize=True,
            )
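            # HiFT vocoder: convert the mel-spectrogram into a waveform. An empty cache_source
            # is passed because this is a single, non-streaming call.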
            audio_hat, _ = self.token2wav_model.model.hift.inference(
                speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
            )

            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
            responses.append(inference_response)

        return responses