model.py

# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class CosyVoice2:

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
        self.model_dir = model_dir
        self.fp16 = fp16
        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
        self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                trt_concurrent,
                                self.fp16)
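
# For reference, CosyVoice2.__init__ above expects `model_dir` to contain at least
# the following entries (inferred from the paths constructed in this file; other
# CosyVoice2 releases may use different names):
#   cosyvoice2.yaml                                hyperpyyaml config providing the `flow` and `hift` modules
#   CosyVoice-BlankEN/                             Qwen pretrain directory referenced by the config
#   flow.pt, hift.pt                               flow matching decoder and HiFT vocoder weights
#   flow.encoder.{fp16,fp32}.zip                   TorchScript flow encoder (used when load_jit=True)
#   flow.decoder.estimator.fp32.onnx               ONNX estimator used to build the TensorRT plan
#   flow.decoder.estimator.{fp16,fp32}.mygpu.plan  TensorRT engine (rebuilt on first use if missing or empty)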


class CosyVoice2Model:

    def __init__(self,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        if self.fp16 is True:
            self.flow.half()

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load(self, flow_model, hift_model):
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
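
# How the shape profile above is consumed (a sketch; only the tuples and input
# names come from get_trt_kwargs(), the rest is inferred): each tuple is
# (batch, channels, frames) for the estimator inputs "x", "mask", "mu", "cond",
# i.e. 80-channel features (1 channel for the mask) over 4 to 3000 frames, with
# a fixed batch of 2 that presumably covers the classifier-free-guidance pair.
# When the .plan file is missing or empty, load_trt() builds it via:
#
#     convert_onnx_to_trt(plan_path, self.get_trt_kwargs(), onnx_path, fp16)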


class TritonPythonModel:
    """Triton Python model for the CosyVoice2 token-to-waveform stage.

    This model takes target and prompt speech tokens, prompt speech features and a
    prompt speaker embedding as input, and generates audio waveforms with the flow
    matching decoder followed by the HiFT vocoder.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        # Initialize device and the token2wav model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
        self.token2wav_model = CosyVoice2(
            model_dir, load_jit=True, load_trt=True, fp16=True
        )
        logger.info("Token2Wav initialized successfully")

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests
        Returns:
            List of inference responses containing generated waveforms
        """
        responses = []
        # Process each request in batch
        for request in requests:
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
            prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
            prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()

            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
            prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
            prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
            prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)

            # Incoming speech-token ids are offset by 151663 (where the speech-token
            # range starts in the combined text+speech vocabulary); subtract the
            # offset to recover raw codec token ids before flow inference.
            prompt_speech_tokens = prompt_speech_tokens - 151663
            target_speech_tokens = target_speech_tokens - 151663

            # Flow matching: speech tokens -> mel spectrogram, conditioned on the prompt.
            tts_mel, _ = self.token2wav_model.model.flow.inference(
                token=target_speech_tokens,
                token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                prompt_token=prompt_speech_tokens,
                prompt_token_len=torch.tensor([prompt_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                prompt_feat=prompt_speech_feat,
                prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
                embedding=prompt_spk_embedding,
                streaming=False,
                finalize=True,
            )

            # HiFT vocoder: mel spectrogram -> waveform.
            audio_hat, _ = self.token2wav_model.model.hift.inference(
                speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
            )
            generated_wave = audio_hat.squeeze(0).cpu().numpy()

            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
            responses.append(inference_response)
        return responses
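
# ---------------------------------------------------------------------------
# Client-side usage sketch (illustrative only; kept in a comment so this module
# stays importable inside Triton). Only the tensor names come from execute()
# above; the model name, dtypes and shapes are assumptions and must match the
# accompanying config.pbtxt.
#
#     import tritonclient.grpc as grpcclient
#     from tritonclient.utils import np_to_triton_dtype
#
#     client = grpcclient.InferenceServerClient("localhost:8001")
#     arrays = {
#         "target_speech_tokens": target_tokens,   # e.g. int32, shape (1, T_target)
#         "prompt_speech_tokens": prompt_tokens,   # e.g. int32, shape (1, T_prompt)
#         "prompt_speech_feat": prompt_feat,       # e.g. float16, shape (1, T_feat, 80)
#         "prompt_spk_embedding": spk_embedding,   # e.g. float16, shape (1, 192)
#     }
#     inputs = []
#     for name, arr in arrays.items():
#         t = grpcclient.InferInput(name, list(arr.shape), np_to_triton_dtype(arr.dtype))
#         t.set_data_from_numpy(arr)
#         inputs.append(t)
#     result = client.infer("token2wav",  # placeholder model name
#                           inputs,
#                           outputs=[grpcclient.InferRequestedOutput("waveform")])
#     waveform = result.as_numpy("waveform")  # generated audio samples
# ---------------------------------------------------------------------------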