# model.py

# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from collections import defaultdict

import numpy as np
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F

import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml

from cosyvoice.utils.common import fade_in_out, TrtContextWrapper
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Speech-token ids emitted by the LLM start after the text vocabulary; they are shifted
# back by this offset before being fed to the flow decoder (see execute() below).
ORIGINAL_VOCAB_SIZE = 151663

# Limit intra-op CPU threads per model instance.
torch.set_num_threads(1)


class CosyVoice2:

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
        self.model_dir = model_dir
        self.fp16 = fp16
        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
        self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                trt_concurrent,
                                self.fp16)
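

# Usage note (a sketch; the path below is a placeholder): outside Triton this wrapper can be
# built directly from a CosyVoice2 model directory containing cosyvoice2.yaml, flow.pt,
# hift.pt, the CosyVoice-BlankEN tokenizer folder and, for the Triton backend below,
# spk2info.pt, e.g. `CosyVoice2('/path/to/CosyVoice2-0.5B', load_trt=True, fp16=True)`.
# With load_trt=True, the flow.decoder.estimator TensorRT engine is built from the ONNX
# export on first use if the .plan file is missing or empty.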


class CosyVoice2Model:

    def __init__(self,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False,
                 device: str = 'cuda'):
        self.device = device
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        if self.fp16 is True:
            self.flow.half()
        # streaming tts config
        self.token_hop_len = 25
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        self.speech_window = np.hamming(2 * self.source_cache_len)
        self.hift_cache_dict = defaultdict(lambda: None)

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load(self, flow_model, hift_model):
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_trt_kwargs(self):
        # dynamic-shape profile (min/opt/max over the mel time axis) used to build the TensorRT engine
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_token=prompt_token.to(self.device),
                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_feat=prompt_feat.to(self.device),
                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=embedding.to(self.device),
                                             streaming=stream,
                                             finalize=finalize)
            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
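

# A minimal sketch (not used by the Triton backend below) of how token2wav is typically
# driven for streaming synthesis: speech tokens accumulate, each call vocodes one hop of
# `token_hop_len` new tokens past `token_offset`, and a final call with finalize=True
# flushes the remainder and releases the per-utterance HiFT cache. The chunking policy
# here is a simplification; the upstream CosyVoice2 pipeline additionally feeds a few
# look-ahead tokens per chunk. Assumes at least one non-empty token chunk arrives.
def example_streaming_token2wav(model: CosyVoice2Model, token_chunks, prompt_token,
                                prompt_feat, embedding, uuid='example-utterance'):
    """Yield waveform chunks as speech-token chunks arrive (illustrative only)."""
    tokens = None
    token_offset = 0
    for chunk in token_chunks:  # each chunk: (1, N) tensor of newly generated speech tokens
        tokens = chunk if tokens is None else torch.concat([tokens, chunk], dim=1)
        # vocode one hop at a time once enough new tokens have accumulated
        while tokens.shape[1] - token_offset >= model.token_hop_len:
            yield model.token2wav(tokens[:, :token_offset + model.token_hop_len],
                                  prompt_token, prompt_feat, embedding,
                                  token_offset=token_offset, uuid=uuid,
                                  stream=True, finalize=False)
            token_offset += model.token_hop_len
    # final call: vocode whatever is left, then drop the cache entry for this utterance
    yield model.token2wav(tokens, prompt_token, prompt_feat, embedding,
                          token_offset=token_offset, uuid=uuid, stream=True, finalize=True)
    model.hift_cache_dict.pop(uuid, None)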


class TritonPythonModel:
    """Triton Python model for the CosyVoice2 token-to-waveform stage.

    This model takes speech tokens (plus optional prompt tokens, prompt mel features and a
    speaker embedding) as input and generates audio waveforms with the flow-matching
    decoder and the HiFT vocoder.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        # Initialize device and vocoder
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
        self.token2wav_model = CosyVoice2(
            model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
        )

        # Default speaker info, used when a request carries no prompt inputs
        spk_info_path = os.path.join(model_dir, "spk2info.pt")
        if not os.path.exists(spk_info_path):
            raise ValueError(f"spk2info.pt not found in {model_dir}")
        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
        self.default_spk_info = spk_info["001"]

        logger.info("Token2Wav initialized successfully")
    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests
        Returns:
            List of inference responses containing generated waveforms
        """
        responses = []
        # Process each request in batch
        for request in requests:
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)

            # Optional prompt inputs; fall back to the default speaker info when absent
            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
            if prompt_speech_tokens_tensor is not None:
                prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
                prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
                prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
                prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
                prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
                prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
                prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
            else:
                prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
                prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
                prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)

            # shift the speech tokens according to the original vocab size
            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE

            # token_offset is an optional input used for streaming tts; it must be absent (None) for offline tts
            token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
            if token_offset is not None:
                # Streaming path: vocode only the tokens past token_offset, keeping per-request caches
                token_offset = token_offset.as_numpy().item()
                finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
                stream = not finalize
                request_id = request.request_id()
                audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
                                                                 prompt_token=prompt_speech_tokens,
                                                                 prompt_feat=prompt_speech_feat,
                                                                 embedding=prompt_spk_embedding,
                                                                 token_offset=token_offset,
                                                                 uuid=request_id,
                                                                 stream=stream,
                                                                 finalize=finalize)
                if finalize:
                    # Last chunk of this utterance: release the cached HiFT state
                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
            else:
                # Offline path: run flow and HiFT over the full utterance in one shot
                tts_mel, _ = self.token2wav_model.model.flow.inference(
                    token=target_speech_tokens,
                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_token=prompt_speech_tokens,
                    prompt_token_len=torch.tensor([prompt_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_feat=prompt_speech_feat,
                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
                    embedding=prompt_spk_embedding,
                    streaming=False,
                    finalize=True,
                )
                audio_hat, _ = self.token2wav_model.model.hift.inference(
                    speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
                )

            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
            responses.append(inference_response)

        return responses
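

# Illustrative client-side usage: a minimal sketch, not part of the Triton model itself.
# The tensor names ("target_speech_tokens", "finalize", "waveform") match execute() above,
# but the model name "token2wav", the datatypes and the shapes are assumptions that must be
# adjusted to the deployment's config.pbtxt, and the token ids are placeholders. Guarded by
# __main__ so it never runs when Triton imports this module.
if __name__ == "__main__":
    import tritonclient.grpc as grpcclient

    client = grpcclient.InferenceServerClient(url="localhost:8001")

    # Offline request: "token_offset" is omitted so execute() takes the non-streaming branch,
    # and the prompt inputs are omitted so the default speaker from spk2info.pt is used.
    tokens = np.zeros((1, 50), dtype=np.int32) + ORIGINAL_VOCAB_SIZE  # placeholder speech-token ids
    finalize = np.array([[True]], dtype=bool)

    inputs = [
        grpcclient.InferInput("target_speech_tokens", list(tokens.shape), "INT32"),
        grpcclient.InferInput("finalize", list(finalize.shape), "BOOL"),
    ]
    inputs[0].set_data_from_numpy(tokens)
    inputs[1].set_data_from_numpy(finalize)

    result = client.infer("token2wav", inputs,
                          outputs=[grpcclient.InferRequestedOutput("waveform")])
    waveform = result.as_numpy("waveform")
    print("waveform shape:", waveform.shape)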