# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import math
import os
import re
from typing import Dict, List, Tuple, Optional, Union

import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import onnxruntime

from matcha.utils.audio import mel_spectrogram


class TritonPythonModel:
    """Triton Python model for Spark TTS.

    This model orchestrates the end-to-end TTS pipeline by coordinating
    the audio tokenizer, LLM, and vocoder components.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        self.logger = pb_utils.Logger

        # Parse model parameters
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.logger.log_info(f"model_params: {model_params}")

        # Initialize tokenizer
        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")

        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)

        # CAM++ speaker-embedding model, run on CPU via ONNX Runtime
        campplus_model = f'{model_params["model_dir"]}/campplus.onnx'
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(
            campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
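
    # NOTE: a minimal sketch (not taken from this repository) of the `parameters`
    # block that initialize() expects in the model's config.pbtxt; the paths are
    # placeholders and must match your deployment:
    #
    #   parameters {
    #     key: "llm_tokenizer_dir"
    #     value: { string_value: "/workspace/llm_tokenizer" }
    #   }
    #   parameters {
    #     key: "model_dir"
    #     value: { string_value: "/workspace/spark_tts_assets" }
    #   }
    #
    # Streaming behavior is controlled separately by the decoupled model
    # transaction policy in the same config, not by a parameter here.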

    def forward_llm(self, input_ids):
        """
        Generates speech tokens with the language model. Creates a
        `pb_utils.InferenceRequest` from the provided input IDs and sends it
        to the (optionally decoupled) TensorRT-LLM model.

        For each response from the language model:
        - Checks for errors and raises an exception if any are found.
        - Extracts the "output_ids" and "sequence_length" tensors.
        - Yields the generated token IDs up to the reported sequence length.

        Parameters
        ----------
        input_ids : torch.Tensor
            Prompt token IDs with shape [1, sequence_length].

        Yields
        ------
        numpy.ndarray
            Generated token IDs: one chunk per streamed response in decoupled
            mode, or a single array otherwise.
        """
        # Convert input_ids to numpy, with shape [1, sequence_length]
        input_ids = input_ids.cpu().numpy()
        max_tokens = 1024
        input_dict = {
            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
            "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
            "pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
            "streaming": np.array([[self.decoupled]], dtype=np.bool_),
            "runtime_top_p": np.array([[0.95]], dtype=np.float32),
            "runtime_top_k": np.array([[50]], dtype=np.int32),
            "temperature": np.array([[0.8]], dtype=np.float32),
            "input_ids": input_ids,
            "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
        }

        # Convert inputs to Triton tensors
        input_tensor_list = [
            pb_utils.Tensor(k, v) for k, v in input_dict.items()
        ]

        # Create and execute inference request
        llm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            requested_output_names=["output_ids", "sequence_length"],
            inputs=input_tensor_list,
        )
        llm_responses = llm_request.exec(decoupled=self.decoupled)

        if self.decoupled:
            for llm_response in llm_responses:
                if llm_response.has_error():
                    raise pb_utils.TritonModelException(llm_response.error().message())

                # Extract and process output
                output_ids = pb_utils.get_output_tensor_by_name(
                    llm_response, "output_ids").as_numpy()
                seq_lens = pb_utils.get_output_tensor_by_name(
                    llm_response, "sequence_length").as_numpy()

                # Get actual output IDs up to the sequence length
                actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
                yield actual_output_ids
        else:
            llm_response = llm_responses
            if llm_response.has_error():
                raise pb_utils.TritonModelException(llm_response.error().message())

            # Extract and process output
            output_ids = pb_utils.get_output_tensor_by_name(
                llm_response, "output_ids").as_numpy()
            seq_lens = pb_utils.get_output_tensor_by_name(
                llm_response, "sequence_length").as_numpy()

            # Get actual output IDs up to the sequence length
            actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
            yield actual_output_ids

    def forward_audio_tokenizer(self, wav, wav_len):
        """Forward pass through the audio tokenizer component.

        Args:
            wav: Input waveform tensor
            wav_len: Waveform length tensor

        Returns:
            Prompt speech tokens tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['prompt_speech_tokens'],
            inputs=[wav, wav_len]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract the prompt speech tokens and move them to host memory
        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
        prompt_speech_tokens = from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()

        return prompt_speech_tokens

    def forward_token2wav(
            self,
            prompt_speech_tokens: torch.Tensor,
            prompt_speech_feat: torch.Tensor,
            prompt_spk_embedding: torch.Tensor,
            target_speech_tokens: torch.Tensor) -> torch.Tensor:
        """Forward pass through the vocoder component.

        Args:
            prompt_speech_tokens: Prompt speech tokens tensor
            prompt_speech_feat: Prompt speech feature tensor
            prompt_spk_embedding: Prompt speaker embedding tensor
            target_speech_tokens: Target speech tokens tensor

        Returns:
            Generated waveform tensor
        """
        prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
        prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
        prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav',
            requested_output_names=['waveform'],
            inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output waveform
        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = from_dlpack(waveform.to_dlpack()).cpu()

        return waveform

    def parse_input(self, text, prompt_text, prompt_speech_tokens):
        """Build the LLM input IDs from reference text, target text, and prompt speech tokens."""
        total_text = f"{prompt_text}{text}"
        prompt = self.prompt_template.format(input_text=total_text)
        input_ids = self.tokenizer.encode(prompt)
        input_ids = torch.tensor([input_ids], dtype=torch.int32)
        input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
        return input_ids
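
    # Illustrative example (values assumed, not from the repository): with
    # prompt_text="Reference transcript." and text="Target sentence.", the LLM
    # input is the token IDs of
    #   "<|sos|>Reference transcript.Target sentence.<|task_id|>"
    # concatenated along dim=1 with the prompt speech token IDs returned by
    # the audio_tokenizer model.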

    def _extract_spk_embedding(self, speech):
        """Extract a speaker embedding from 16 kHz speech with the CAM++ ONNX model."""
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.campplus_session.run(
            None,
            {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device).half()
        return embedding

    def _extract_speech_feat(self, speech):
        """Extract an 80-bin mel spectrogram from 24 kHz speech for the vocoder prompt."""
        speech_feat = mel_spectrogram(
            speech,
            n_fft=1920,
            num_mels=80,
            sampling_rate=24000,
            hop_size=480,
            win_size=1920,
            fmin=0,
            fmax=8000)
        speech_feat = speech_feat.squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        return speech_feat
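
    # Worked numbers implied by the settings above (illustrative): at 24 kHz with
    # hop_size=480, the mel spectrogram has 24000 / 480 = 50 frames per second, and
    # execute() pairs 2 mel frames with each prompt speech token, i.e. an assumed
    # token rate of 25 tokens per second.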

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing generated audio
        """
        responses = []

        for request in requests:
            # Extract input tensors
            wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")

            # Process reference audio through audio tokenizer
            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)

            wav_tensor = wav.as_numpy()
            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
            speech_feat = self._extract_speech_feat(prompt_speech_resample)

            # Align the mel features (2 frames per token) with the prompt speech tokens
            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode('utf-8')
            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
            target_text = target_text[0][0].decode('utf-8')

            # Prepare prompt for LLM
            input_ids = self.parse_input(
                text=target_text,
                prompt_text=reference_text,
                prompt_speech_tokens=prompt_speech_tokens,
            )

            # Generate semantic tokens with LLM
            generated_ids_iter = self.forward_llm(input_ids)

            if self.decoupled:
                response_sender = request.get_response_sender()
                request_id = request.request_id()
                generated_ids = []
                for generated_id in generated_ids_iter:
                    # In streaming mode each response carries at most one new token ID
                    generated_id = generated_id.tolist()
                    if len(generated_id) > 0:
                        assert len(generated_id) == 1, "Generated ID is not a single integer"
                        generated_ids.append(generated_id[0])
                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)

                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)

                # Prepare response
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                response_sender.send(inference_response)
                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                self.logger.log_info("Sent TRITONSERVER_RESPONSE_COMPLETE_FINAL to close the stream")
            else:
                generated_ids = next(generated_ids_iter)
                if generated_ids is None or len(generated_ids) == 0:
                    raise pb_utils.TritonModelException("Generated IDs is None or empty")
                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)

                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)

                # Prepare response
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                responses.append(inference_response)

        if not self.decoupled:
            return responses
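

# ---------------------------------------------------------------------------
# Minimal client sketch (illustrative only, not part of the Triton model).
# It shows how this model's inputs and outputs line up when called over gRPC
# in the non-decoupled (non-streaming) configuration. The server URL, the
# model name "spark_tts", and the synthetic reference waveform are assumptions
# for illustration; adjust them to your deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tritonclient.grpc as grpcclient
    from tritonclient.utils import np_to_triton_dtype

    client = grpcclient.InferenceServerClient(url="localhost:8001")

    # One second of silence at 16 kHz stands in for a real reference recording.
    reference_wav = np.zeros((1, 16000), dtype=np.float32)
    reference_wav_len = np.array([[reference_wav.shape[1]]], dtype=np.int32)
    reference_text = np.array([["reference transcript".encode("utf-8")]], dtype=object)
    target_text = np.array([["text to synthesize".encode("utf-8")]], dtype=object)

    inputs = []
    for name, array in [
        ("reference_wav", reference_wav),
        ("reference_wav_len", reference_wav_len),
        ("reference_text", reference_text),
        ("target_text", target_text),
    ]:
        tensor = grpcclient.InferInput(name, list(array.shape), np_to_triton_dtype(array.dtype))
        tensor.set_data_from_numpy(array)
        inputs.append(tensor)

    result = client.infer(
        model_name="spark_tts",  # assumed name; match your model repository
        inputs=inputs,
        outputs=[grpcclient.InferRequestedOutput("waveform")],
    )
    print("generated waveform shape:", result.as_numpy("waveform").shape)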