# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Speech token ids produced by the upstream LLM are shifted by this original (text)
# vocab size; subtracting it maps them back into the flow model's token range.
ORIGINAL_VOCAB_SIZE = 151663

torch.set_num_threads(1)


class CosyVoice2:

    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
        self.model_dir = model_dir
        self.fp16 = fp16
        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError('{} not found!'.format(hyper_yaml_path))
        with open(hyper_yaml_path, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
        self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                trt_concurrent,
                                self.fp16)
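

# Expected layout of `model_dir` (a sketch inferred from the loaders above and from
# TritonPythonModel.initialize below; exact contents depend on the deployed checkpoint):
#   cosyvoice2.yaml                                 hyperpyyaml config defining flow/hift
#   CosyVoice-BlankEN/                              Qwen pretrain path referenced by the config
#   flow.pt, hift.pt                                flow and HiFT vocoder weights
#   flow.encoder.{fp16,fp32}.zip                    optional TorchScript encoder (load_jit=True)
#   flow.decoder.estimator.fp32.onnx                ONNX estimator used to build the TRT engine
#   flow.decoder.estimator.{fp16,fp32}.mygpu.plan   TRT engine, built from the ONNX file if missing
#   spk2info.pt                                     default speaker prompt used when no prompt is sent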


class CosyVoice2Model:

    def __init__(self,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False,
                 device: str = 'cuda'):
        self.device = device
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        if self.fp16 is True:
            self.flow.half()
        # streaming tts config
        self.token_hop_len = 25
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        self.speech_window = np.hamming(2 * self.source_cache_len)
        self.hift_cache_dict = defaultdict(lambda: None)

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load(self, flow_model, hift_model):
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_token=prompt_token.to(self.device),
                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_feat=prompt_feat.to(self.device),
                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=embedding.to(self.device),
                                             streaming=stream,
                                             finalize=finalize)
            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
            # append hift cache
            if self.hift_cache_dict[uuid] is not None:
                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
            else:
                hift_cache_source = torch.zeros(1, 1, 0)
            # keep overlap mel and hift cache
            if finalize is False:
                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
                if self.hift_cache_dict[uuid] is not None:
                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
                self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                              'source': tts_source[:, :, -self.source_cache_len:],
                                              'speech': tts_speech[:, -self.source_cache_len:]}
                tts_speech = tts_speech[:, :-self.source_cache_len]
            else:
                if speed != 1.0:
                    assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                    tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
                if self.hift_cache_dict[uuid] is not None:
                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
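

# Streaming contract for token2wav, as driven by TritonPythonModel.execute below
# (a sketch of the expected call pattern, not an additional API):
#   - each call receives the tokens generated so far; `token_offset` says how many
#     tokens were already vocoded, and the returned mel is sliced at
#     token_offset * flow.token_mel_ratio so only new audio is synthesized;
#   - intermediate chunks use finalize=False, which keeps the last `mel_cache_len`
#     mel frames / `source_cache_len` samples in hift_cache_dict[uuid] and
#     cross-fades them into the next chunk with `speech_window`;
#   - the last chunk uses finalize=True, after which the caller should drop
#     hift_cache_dict[uuid] (execute() pops it once finalize is seen).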


class TritonPythonModel:
    """Triton Python model for the CosyVoice2 token-to-waveform stage.

    This model takes CosyVoice2 speech tokens (plus an optional speaker prompt) as
    input and generates audio waveforms using the flow model and HiFT vocoder.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        # Parse model parameters
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        # Initialize device and vocoder
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
        self.token2wav_model = CosyVoice2(
            model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
        )

        # Default speaker prompt, used when a request carries no prompt inputs
        spk_info_path = os.path.join(model_dir, "spk2info.pt")
        if not os.path.exists(spk_info_path):
            raise ValueError(f"spk2info.pt not found in {model_dir}")
        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
        self.default_spk_info = spk_info["001"]
        logger.info("Token2Wav initialized successfully")
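
    # The only backend parameter read in initialize() is `model_dir`; in config.pbtxt
    # it would look roughly like the following (the path is a placeholder):
    #
    #   parameters [
    #     { key: "model_dir", value: { string_value: "/workspace/CosyVoice2-0.5B" } }
    #   ]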

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing generated waveforms
        """
        responses = []
        # Process each request in batch
        for request in requests:
            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)

            # Optional speaker prompt; fall back to the default speaker from spk2info.pt
            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
            if prompt_speech_tokens_tensor is not None:
                prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
                prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
                prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
                prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
                prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
                prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
                prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
            else:
                prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
                prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
                prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)

            # shift the speech tokens according to the original vocab size
            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE

            # token_offset is an optional input that selects streaming mode; it must be
            # absent (None) for offline tts.
            token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
            if token_offset is not None:
                # Streaming: reuse the per-request hift cache keyed by request_id
                token_offset = token_offset.as_numpy().item()
                finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
                stream = not finalize
                request_id = request.request_id()
                audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
                                                                 prompt_token=prompt_speech_tokens,
                                                                 prompt_feat=prompt_speech_feat,
                                                                 embedding=prompt_spk_embedding,
                                                                 token_offset=token_offset,
                                                                 uuid=request_id,
                                                                 stream=stream,
                                                                 finalize=finalize)
                if finalize:
                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
            else:
                # Offline: vocode the whole utterance in one flow + hift pass
                tts_mel, _ = self.token2wav_model.model.flow.inference(
                    token=target_speech_tokens,
                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_token=prompt_speech_tokens,
                    prompt_token_len=torch.tensor([prompt_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
                    prompt_feat=prompt_speech_feat,
                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
                    embedding=prompt_spk_embedding,
                    streaming=False,
                    finalize=True,
                )
                audio_hat, _ = self.token2wav_model.model.hift.inference(
                    speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
                )

            generated_wave = audio_hat.squeeze(0).cpu().numpy()

            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
            responses.append(inference_response)

        return responses
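

# --- Example client (sketch) --------------------------------------------------------
# A minimal offline request against this model using tritonclient. The model name,
# tensor dtypes, and shapes are assumptions and must match your config.pbtxt; the
# token ids are placeholders standing in for ids produced by the upstream LLM.
if __name__ == "__main__":
    import tritonclient.grpc as grpcclient
    from tritonclient.utils import np_to_triton_dtype

    client = grpcclient.InferenceServerClient(url="localhost:8001")

    # Placeholder ids in the LLM's shifted range [ORIGINAL_VOCAB_SIZE, ...);
    # execute() maps them back by subtracting ORIGINAL_VOCAB_SIZE.
    target_tokens = ((np.arange(50, dtype=np.int64) % 100) + ORIGINAL_VOCAB_SIZE).reshape(1, -1)

    inp = grpcclient.InferInput("target_speech_tokens", list(target_tokens.shape),
                                np_to_triton_dtype(target_tokens.dtype))
    inp.set_data_from_numpy(target_tokens)

    # Prompt inputs and token_offset are omitted, so the default speaker from
    # spk2info.pt is used and the offline (non-streaming) path runs.
    result = client.infer(model_name="token2wav",  # assumed model name
                          inputs=[inp],
                          outputs=[grpcclient.InferRequestedOutput("waveform")],
                          request_id="demo")
    waveform = result.as_numpy("waveform")
    print(waveform.shape)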