model.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of NVIDIA CORPORATION nor the names of its
  12. # contributors may be used to endorse or promote products derived
  13. # from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
  16. # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  17. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
  18. # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
  19. # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
  20. # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
  21. # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
  22. # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
  23. # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  24. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  25. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  26. import json
  27. import math
  28. import os
  29. import re
  30. from typing import Dict, List, Tuple, Optional, Union
  31. import numpy as np
  32. import torch
  33. from torch.utils.dlpack import from_dlpack, to_dlpack
  34. import triton_python_backend_utils as pb_utils
  35. from transformers import AutoTokenizer
  36. import torchaudio
  37. from matcha.utils.audio import mel_spectrogram
  38. class TritonPythonModel:
  39. """Triton Python model for Spark TTS.
  40. This model orchestrates the end-to-end TTS pipeline by coordinating
  41. between audio tokenizer, LLM, and vocoder components.
  42. """
  43. def initialize(self, args):
  44. """Initialize the model.
  45. Args:
  46. args: Dictionary containing model configuration
  47. """
  48. self.logger = pb_utils.Logger
  49. # Parse model parameters
  50. self.model_config = json.loads(args['model_config'])
  51. parameters = self.model_config['parameters']
  52. model_params = {k: v["string_value"] for k, v in parameters.items()}
  53. self.logger.log_info(f"model_params:{model_params}")
  54. # Initialize tokenizer
  55. llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
  56. self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
  57. self.prompt_template = "<|sos|>{input_text}<|task_id|>"
  58. self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
  59. self.device = torch.device("cuda")
  60. self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
  61. def forward_llm(self, input_ids):
  62. """
  63. Prepares the response from the language model based on the provided
  64. inputs. Creates a `pb_utils.InferenceRequest` object with passed
  65. `llm_request_inputs` to send to a decoupled TensorRTLLM model.
  66. For each response from the language model:
  67. - Checks for errors and raise an exception if any are found.
  68. - Extracts the "output_ids" tensor from the response.
  69. - Determines the finish reason based on the presence of the
  70. end-of-sequence token or reaching the maximum length.
  71. - Appends the generated token IDs to `output_ids`.
  72. - If the finish reason is determined, decodes the output IDs to text
  73. and prepares the final response.
  74. The final response includes the generated text, finish reason,
  75. completion tokens, prompt tokens, and total tokens.
  76. Parameters
  77. ----------
  78. - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
  79. Returns
  80. -------
  81. - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
  82. """
  83. # convert input_ids to numpy, with shape [1, sequence_length]
  84. input_ids = input_ids.cpu().numpy()
  85. max_tokens = 1024
  86. input_dict = {
  87. "request_output_len": np.array([[max_tokens]], dtype=np.int32),
  88. "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
  89. "pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
  90. "streaming": np.array([[self.decoupled]], dtype=np.bool_),
  91. "runtime_top_p": np.array([[0.95]], dtype=np.float32),
  92. "runtime_top_k": np.array([[50]], dtype=np.int32),
  93. "temperature": np.array([[0.8]], dtype=np.float32),
  94. "repetition_penalty": np.array([[1.1]], dtype=np.float32),
  95. "input_ids": input_ids,
  96. "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
  97. }
  98. # Convert inputs to Triton tensors
  99. input_tensor_list = [
  100. pb_utils.Tensor(k, v) for k, v in input_dict.items()
  101. ]
  102. # Create and execute inference request
  103. llm_request = pb_utils.InferenceRequest(
  104. model_name="tensorrt_llm",
  105. requested_output_names=["output_ids", "sequence_length"],
  106. inputs=input_tensor_list,
  107. )
  108. llm_responses = llm_request.exec(decoupled=self.decoupled)
  109. if self.decoupled:
  110. for llm_response in llm_responses:
  111. if llm_response.has_error():
  112. raise pb_utils.TritonModelException(llm_response.error().message())
  113. # Extract and process output
  114. output_ids = pb_utils.get_output_tensor_by_name(
  115. llm_response, "output_ids").as_numpy()
  116. seq_lens = pb_utils.get_output_tensor_by_name(
  117. llm_response, "sequence_length").as_numpy()
  118. # Get actual output IDs up to the sequence length
  119. actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
  120. print(f"actual_output_ids: {actual_output_ids}")
  121. yield actual_output_ids
  122. else:
  123. llm_response = llm_responses
  124. if llm_response.has_error():
  125. raise pb_utils.TritonModelException(llm_response.error().message())
  126. # Extract and process output
  127. output_ids = pb_utils.get_output_tensor_by_name(
  128. llm_response, "output_ids").as_numpy()
  129. seq_lens = pb_utils.get_output_tensor_by_name(
  130. llm_response, "sequence_length").as_numpy()
  131. # Get actual output IDs up to the sequence length
  132. actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
  133. yield actual_output_ids
  134. def forward_audio_tokenizer(self, wav, wav_len):
  135. """Forward pass through the audio tokenizer component.
  136. Args:
  137. wav: Input waveform tensor
  138. wav_len: Waveform length tensor
  139. Returns:
  140. Tuple of global and semantic tokens
  141. """
  142. inference_request = pb_utils.InferenceRequest(
  143. model_name='audio_tokenizer',
  144. requested_output_names=['prompt_speech_tokens'],
  145. inputs=[wav, wav_len]
  146. )
  147. inference_response = inference_request.exec()
  148. if inference_response.has_error():
  149. raise pb_utils.TritonModelException(inference_response.error().message())
  150. # Extract and convert output tensors
  151. prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
  152. prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
  153. return prompt_speech_tokens
  154. def forward_speaker_embedding(self, wav):
  155. """Forward pass through the speaker embedding component.
  156. Args:
  157. wav: Input waveform tensor
  158. Returns:
  159. Prompt speaker embedding tensor
  160. """
  161. inference_request = pb_utils.InferenceRequest(
  162. model_name='speaker_embedding',
  163. requested_output_names=['prompt_spk_embedding'],
  164. inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
  165. )
  166. inference_response = inference_request.exec()
  167. if inference_response.has_error():
  168. raise pb_utils.TritonModelException(inference_response.error().message())
  169. # Extract and convert output tensors
  170. prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
  171. prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
  172. return prompt_spk_embedding
  173. def forward_token2wav(
  174. self,
  175. prompt_speech_tokens: torch.Tensor,
  176. prompt_speech_feat: torch.Tensor,
  177. prompt_spk_embedding: torch.Tensor,
  178. target_speech_tokens: torch.Tensor,
  179. request_id: str,
  180. token_offset: int = None,
  181. finalize: bool = None) -> torch.Tensor:
  182. """Forward pass through the vocoder component.
  183. Args:
  184. prompt_speech_tokens: Prompt speech tokens tensor
  185. prompt_speech_feat: Prompt speech feat tensor
  186. prompt_spk_embedding: Prompt spk embedding tensor
  187. target_speech_tokens: Target speech tokens tensor
  188. Returns:
  189. Generated waveform tensor
  190. """
  191. prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
  192. prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
  193. prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
  194. target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
  195. inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
  196. if token_offset is not None:
  197. assert finalize is not None
  198. token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
  199. finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
  200. inputs_tensor.append(token_offset_tensor)
  201. inputs_tensor.append(finalize_tensor)
  202. # Create and execute inference request
  203. inference_request = pb_utils.InferenceRequest(
  204. model_name='token2wav',
  205. requested_output_names=['waveform'],
  206. inputs=inputs_tensor,
  207. request_id=request_id,
  208. )
  209. inference_response = inference_request.exec()
  210. if inference_response.has_error():
  211. raise pb_utils.TritonModelException(inference_response.error().message())
  212. # Extract and convert output waveform
  213. waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
  214. waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
  215. return waveform
  216. def parse_input(self, text, prompt_text, prompt_speech_tokens):
  217. total_text = f"{prompt_text}{text}"
  218. prompt = self.prompt_template.format(input_text=total_text)
  219. input_ids = self.tokenizer.encode(prompt)
  220. input_ids = torch.tensor([input_ids], dtype=torch.int32)
  221. input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
  222. return input_ids
  223. def _extract_speech_feat(self, speech):
  224. speech_feat = mel_spectrogram(
  225. speech,
  226. n_fft=1920,
  227. num_mels=80,
  228. sampling_rate=24000,
  229. hop_size=480,
  230. win_size=1920,
  231. fmin=0,
  232. fmax=8000).squeeze(
  233. dim=0).transpose(
  234. 0,
  235. 1).to(
  236. self.device)
  237. speech_feat = speech_feat.unsqueeze(dim=0)
  238. return speech_feat
  239. def execute(self, requests):
  240. """Execute inference on the batched requests.
  241. Args:
  242. requests: List of inference requests
  243. Returns:
  244. List of inference responses containing generated audio
  245. """
  246. responses = []
  247. for request in requests:
  248. request_id = request.request_id()
  249. # Extract input tensors
  250. wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
  251. wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
  252. # Process reference audio through audio tokenizer
  253. prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
  254. prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
  255. wav_tensor = wav.as_numpy()
  256. wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
  257. prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
  258. speech_feat = self._extract_speech_feat(prompt_speech_resample)
  259. token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
  260. prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
  261. prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
  262. flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]
  263. token_hop_len = 25
  264. flow_pre_lookahead_len = 3
  265. reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
  266. reference_text = reference_text[0][0].decode('utf-8')
  267. target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
  268. target_text = target_text[0][0].decode('utf-8')
  269. # Prepare prompt for LLM
  270. input_ids = self.parse_input(
  271. text=target_text,
  272. prompt_text=reference_text,
  273. prompt_speech_tokens=prompt_speech_tokens,
  274. )
  275. # Generate semantic tokens with LLM
  276. generated_ids_iter = self.forward_llm(input_ids)
  277. prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
  278. print(f"here2")
  279. if self.decoupled:
  280. response_sender = request.get_response_sender()
  281. semantic_token_ids_arr, token_offset = [], 0
  282. for generated_ids in generated_ids_iter:
  283. generated_ids = generated_ids.tolist()
  284. print(f"generated_id: {generated_ids}")
  285. semantic_token_ids_arr.extend(generated_ids)
  286. prompt_token_pad = int(np.ceil(flow_prompt_speech_token_len / token_hop_len) * token_hop_len - flow_prompt_speech_token_len)
  287. this_token_hop_len = token_hop_len + prompt_token_pad if token_offset == 0 else token_hop_len
  288. print(f"this_token_hop_len: {this_token_hop_len}")
  289. if len(semantic_token_ids_arr) - token_offset >= this_token_hop_len + flow_pre_lookahead_len:
  290. this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + flow_pre_lookahead_len]
  291. print(f"this_tts_speech_token: {this_tts_speech_token}")
  292. this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
  293. print(f"here3")
  294. sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, False)
  295. print(f"here4")
  296. # Prepare response to send
  297. audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
  298. inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
  299. response_sender.send(inference_response)
  300. self.logger.log_info(f"[{request_id}]")
  301. token_offset += this_token_hop_len
  302. print(f"here")
  303. this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
  304. sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
  305. audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
  306. inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
  307. response_sender.send(inference_response)
  308. response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
  309. self.logger.log_info("send tritonserver_response_complete_final to end")
  310. else:
  311. generated_ids = next(generated_ids_iter)
  312. generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
  313. if generated_ids is None or len(generated_ids) == 0:
  314. raise pb_utils.TritonModelException("Generated IDs is None or empty")
  315. audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)
  316. # Prepare response
  317. audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
  318. inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
  319. responses.append(inference_response)
  320. if not self.decoupled:
  321. return responses