model.py

# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
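
# Speech tokens emitted by the LLM stage are appended after the text vocabulary,
# so they arrive offset by the text vocab size. ORIGINAL_VOCAB_SIZE is subtracted
# in execute() below to map them back into the flow model's speech-token range.
# (151663 is assumed to be the size of the underlying Qwen text vocabulary.)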
ORIGINAL_VOCAB_SIZE = 151663


class CosyVoice2:

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
        self.model_dir = model_dir
        self.fp16 = fp16
        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
        self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                trt_concurrent,
                                self.fp16)


class CosyVoice2Model:

    def __init__(self,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        if self.fp16 is True:
            self.flow.half()
        # streaming tts config
        self.token_hop_len = 25
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        self.speech_window = np.hamming(2 * self.source_cache_len)
        self.hift_cache_dict = defaultdict(lambda: None)
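
    # Streaming cache notes (descriptive, based on the values above):
    #   * token_hop_len = 25 speech tokens are consumed per streaming hop.
    #   * mel_cache_len = 8 mel frames are kept as overlap between chunks; with
    #     480 output samples per mel frame this gives source_cache_len = 3840
    #     waveform samples (~160 ms at CosyVoice2's default 24 kHz rate -- the
    #     sample rate is an assumption, not read from the config here).
    #   * speech_window is a Hamming window of 2 * source_cache_len samples,
    #     used by fade_in_out() to cross-fade consecutive streamed chunks.
    #   * hift_cache_dict holds per-request (uuid) mel/source/speech caches.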

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load(self, flow_model, hift_model):
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
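
    # TensorRT dynamic-shape profile notes (descriptive):
    #   The shapes above describe the flow decoder estimator inputs "x", "mask",
    #   "mu" and "cond" as (batch, channels, time). 80 is the number of mel bins,
    #   and the time axis may range from 4 to 3000 frames, with 500 frames as the
    #   optimization point. The fixed batch of 2 is presumably because the
    #   estimator runs conditional and unconditional passes together for
    #   classifier-free guidance (an assumption, not verified here).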

    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_token=prompt_token.to(self.device),
                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_feat=prompt_feat.to(self.device),
                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=embedding.to(self.device),
                                             streaming=stream,
                                             finalize=finalize)
        # drop the mel frames that were already generated for previous chunks
        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
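
# Illustrative call pattern for CosyVoice2Model.token2wav (comments only, not
# executed here; names and argument values are hypothetical):
#
#   offline, one shot:
#       audio = model.token2wav(token, prompt_token, prompt_feat, embedding,
#                               token_offset=0, uuid=req_id,
#                               stream=False, finalize=True)
#
#   streaming, chunk by chunk:
#       call token2wav repeatedly with the token sequence generated so far and a
#       growing token_offset (number of tokens already vocoded), with stream=True
#       and finalize=False; pass finalize=True on the last chunk, then drop the
#       per-request cache with hift_cache_dict.pop(uuid). This mirrors how
#       TritonPythonModel.execute() below drives the model.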


class TritonPythonModel:
    """Triton Python model for the CosyVoice2 token-to-waveform (vocoder) stage.

    This model takes prompt and target speech tokens, the prompt mel features and
    the prompt speaker embedding as input, and generates audio waveforms with the
    CosyVoice2 flow-matching decoder and HiFT vocoder.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        # Initialize device and vocoder
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
        self.token2wav_model = CosyVoice2(
            model_dir, load_jit=True, load_trt=True, fp16=True
        )
        logger.info("Token2Wav initialized successfully")
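
    # The "model_dir" parameter is expected to come from the model's config.pbtxt,
    # e.g. (illustrative path, adjust to your deployment):
    #
    #   parameters: {
    #     key: "model_dir"
    #     value: { string_value: "/workspace/CosyVoice2-0.5B" }
    #   }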

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests
        Returns:
            List of inference responses containing generated waveforms
        """
        responses = []
        # Process each request in batch
        for request in requests:
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
            prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
            prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()

            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
            prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
            prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
            prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)

            # shift the speech tokens back by the original vocab size
            prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE

            # token_offset is an optional input used to support streaming TTS;
            # it has to be absent (None) for offline TTS.
            token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
            if token_offset is not None:
                # Streaming path: incremental synthesis with per-request caching.
                token_offset = token_offset.as_numpy().item()
                finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
                stream = not finalize
                request_id = request.request_id()
                logger.info(f"token_offset: {token_offset}, finalize: {finalize}, request_id: {request_id}")
                audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
                                                                 prompt_token=prompt_speech_tokens,
                                                                 prompt_feat=prompt_speech_feat,
                                                                 embedding=prompt_spk_embedding,
                                                                 token_offset=token_offset,
                                                                 uuid=request_id,
                                                                 stream=stream,
                                                                 finalize=finalize)
                if finalize:
                    # Last chunk of this request: release its cache entry.
                    logger.info(f"dict keys: {self.token2wav_model.model.hift_cache_dict.keys()}")
                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
            else:
                # Offline path: run the flow decoder and HiFT vocoder in one shot.
                tts_mel, _ = self.token2wav_model.model.flow.inference(
                    token=target_speech_tokens,
                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_token=prompt_speech_tokens,
                    prompt_token_len=torch.tensor([prompt_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_feat=prompt_speech_feat,
                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
                    embedding=prompt_spk_embedding,
                    streaming=False,
                    finalize=True,
                )
                audio_hat, _ = self.token2wav_model.model.hift.inference(
                    speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
                )

            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
            responses.append(inference_response)
        return responses
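
# Expected Triton I/O for this model, as used by execute() above (names only;
# the exact dtypes and dims live in config.pbtxt, which is not shown here):
#   inputs : target_speech_tokens, prompt_speech_tokens, prompt_speech_feat,
#            prompt_spk_embedding, plus optional token_offset and finalize
#            for streaming requests.
#   output : waveform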