model.py

# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os
import threading
import time
from uuid import uuid4

import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio

from matcha.utils.audio import mel_spectrogram

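# Offset added to cached speech-token ids (see execute()): speech tokens occupy ids above
# the LLM's original text vocabulary.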
ORIGINAL_VOCAB_SIZE = 151663

torch.set_num_threads(1)


class TritonPythonModel:
    """Triton Python model for Spark TTS.

    This model orchestrates the end-to-end TTS pipeline by coordinating
    between audio tokenizer, LLM, and vocoder components.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        self.logger = pb_utils.Logger

        # Parse model parameters
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.logger.log_info(f"model_params:{model_params}")
        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

        # Initialize tokenizer
        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")

        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)

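        # Streaming constants: the LLM emits speech tokens at token_frame_rate tokens per
        # second of audio; each streamed chunk covers token_hop_len tokens plus
        # flow_pre_lookahead_len look-ahead tokens for the token2wav stage.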
        self.token_frame_rate = 25
        self.flow_pre_lookahead_len = 3
        self.token_hop_len = 15

        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
        if not os.path.exists(spk_info_path):
            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
        self.default_spk_info = spk_info["001"]

    def forward_llm(self, input_ids):
        """Generate speech tokens with the TensorRT-LLM model.

        Builds a `pb_utils.InferenceRequest` for the `tensorrt_llm` model from the
        prompt token ids and fixed sampling parameters, executes it (decoupled or
        non-decoupled), and yields the generated token ids. For each response from
        the language model:
        - Checks for errors and raises an exception if any are found.
        - Extracts the "output_ids" and "sequence_length" tensors.
        - Yields the generated token ids trimmed to the reported sequence length.

        Parameters
        ----------
        input_ids : torch.Tensor
            Prompt token ids with shape [1, sequence_length].

        Yields
        ------
        numpy.ndarray
            Generated token ids for each streamed response (a single array in
            non-decoupled mode).
        """
        # Convert input_ids to numpy, with shape [1, sequence_length]
        input_ids = input_ids.cpu().numpy()
        max_tokens = 750
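
        # Fixed sampling configuration for the TRT-LLM request; end_id and pad_id both use the
        # <|eos1|> token, and "streaming" mirrors this model's decoupled transaction policy.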
        input_dict = {
            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
            "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
            "pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
            "streaming": np.array([[self.decoupled]], dtype=np.bool_),
            "runtime_top_p": np.array([[0.95]], dtype=np.float32),
            "runtime_top_k": np.array([[50]], dtype=np.int32),
            "temperature": np.array([[0.8]], dtype=np.float32),
            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
            "random_seed": np.array([[42]], dtype=np.uint64),
            "input_ids": input_ids,
            "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
        }

        # Convert inputs to Triton tensors
        input_tensor_list = [
            pb_utils.Tensor(k, v) for k, v in input_dict.items()
        ]

        # Create and execute inference request
        llm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            requested_output_names=["output_ids", "sequence_length"],
            inputs=input_tensor_list,
        )
        llm_responses = llm_request.exec(decoupled=self.decoupled)

        if self.decoupled:
            for llm_response in llm_responses:
                if llm_response.has_error():
                    raise pb_utils.TritonModelException(llm_response.error().message())

                # Extract and process output
                output_ids = pb_utils.get_output_tensor_by_name(
                    llm_response, "output_ids").as_numpy()
                seq_lens = pb_utils.get_output_tensor_by_name(
                    llm_response, "sequence_length").as_numpy()

                # Get actual output IDs up to the sequence length
                actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
                yield actual_output_ids
        else:
            llm_response = llm_responses
            if llm_response.has_error():
                raise pb_utils.TritonModelException(llm_response.error().message())

            # Extract and process output
            output_ids = pb_utils.get_output_tensor_by_name(
                llm_response, "output_ids").as_numpy()
            seq_lens = pb_utils.get_output_tensor_by_name(
                llm_response, "sequence_length").as_numpy()

            # Get actual output IDs up to the sequence length
            actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
            yield actual_output_ids

    def forward_audio_tokenizer(self, wav, wav_len):
        """Forward pass through the audio tokenizer component.

        Args:
            wav: Input waveform tensor
            wav_len: Waveform length tensor

        Returns:
            Prompt speech tokens tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['prompt_speech_tokens'],
            inputs=[wav, wav_len]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
        prompt_speech_tokens = from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
        return prompt_speech_tokens

    def forward_speaker_embedding(self, wav):
        """Forward pass through the speaker embedding component.

        Args:
            wav: Input waveform tensor

        Returns:
            Prompt speaker embedding tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='speaker_embedding',
            requested_output_names=['prompt_spk_embedding'],
            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
        prompt_spk_embedding = from_dlpack(prompt_spk_embedding.to_dlpack())
        return prompt_spk_embedding

    def forward_token2wav(
            self,
            target_speech_tokens: torch.Tensor,
            request_id: str,
            prompt_speech_tokens: torch.Tensor = None,
            prompt_speech_feat: torch.Tensor = None,
            prompt_spk_embedding: torch.Tensor = None,
            token_offset: int = None,
            finalize: bool = None) -> torch.Tensor:
        """Forward pass through the vocoder component.

        Args:
            target_speech_tokens: Target speech tokens tensor
            request_id: Request id forwarded to the token2wav model
            prompt_speech_tokens: Prompt speech tokens tensor (optional)
            prompt_speech_feat: Prompt speech feat tensor (optional)
            prompt_spk_embedding: Prompt spk embedding tensor (optional)
            token_offset: Offset of the first new token for streaming synthesis (optional)
            finalize: Whether this is the final chunk of the stream (optional)

        Returns:
            Generated waveform tensor
        """
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
        inputs_tensor = [target_speech_tokens_tensor]
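
        # token_offset/finalize are only supplied for streaming (chunked) synthesis; the prompt
        # feature/embedding tensors are only forwarded when a reference speaker was provided.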
        if token_offset is not None:
            assert finalize is not None
            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
            inputs_tensor.append(token_offset_tensor)
            inputs_tensor.append(finalize_tensor)

        if prompt_spk_embedding is not None:
            assert prompt_speech_feat is not None
            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav',
            requested_output_names=['waveform'],
            inputs=inputs_tensor,
            request_id=request_id,
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output waveform
        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = from_dlpack(waveform.to_dlpack()).cpu()
        return waveform

    def parse_input(self, text, prompt_text, prompt_speech_tokens):
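        # Prompt layout: "<|sos|>{prompt_text}{target_text}<|task_id|>" followed by the prompt
        # speech tokens, so the LLM continues the sequence with speech tokens for the target text.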
        total_text = f"{prompt_text}{text}"
        prompt = self.prompt_template.format(input_text=total_text)
        input_ids = self.tokenizer.encode(prompt)
        input_ids = torch.tensor([input_ids], dtype=torch.int32)
        input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
        return input_ids

    def _extract_speech_feat(self, speech):
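        # 80-bin mel spectrogram at 24 kHz with a 480-sample hop (20 ms), i.e. 50 frames per
        # second, which is two mel frames per 25 Hz speech token (see execute()).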
        speech_feat = mel_spectrogram(
            speech,
            n_fft=1920,
            num_mels=80,
            sampling_rate=24000,
            hop_size=480,
            win_size=1920,
            fmin=0,
            fmax=8000,
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        return speech_feat

    def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
        for generated_ids in generated_ids_iter:
            generated_ids = generated_ids.tolist()
            if len(generated_ids) == 0:
                break
            semantic_token_ids_arr.extend(generated_ids)
        llm_is_done_flag[0] = True

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing generated audio
        """
        responses = []

        for request in requests:
            request_id = request.request_id()

            # Extract input tensors
            wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")

            # Process reference audio through audio tokenizer
            if wav is not None:
                wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
                prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
                prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)

                wav_tensor = wav.as_numpy()
                wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
                prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
                speech_feat = self._extract_speech_feat(prompt_speech_resample)
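
                # Align prompt features and tokens: two mel frames per speech token, truncating
                # both to a consistent length.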
                token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
                prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
                prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

                reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
                reference_text = reference_text[0][0].decode('utf-8')

                prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
            else:
                # Use the pre-cached reference text and speech tokens of the default speaker
                reference_text = self.default_spk_info["prompt_text"]
                prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
                prompt_speech_feat = None
                prompt_spk_embedding = None

            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
            target_text = target_text[0][0].decode('utf-8')

            # Prepare prompt for LLM
            input_ids = self.parse_input(
                text=target_text,
                prompt_text=reference_text,
                prompt_speech_tokens=prompt_speech_tokens,
            )

            # Generate semantic tokens with LLM
            generated_ids_iter = self.forward_llm(input_ids)

            token2wav_request_id = request_id or str(uuid4())

            if self.decoupled:
                response_sender = request.get_response_sender()
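
                # Producer/consumer: a background thread drains the LLM token iterator into
                # semantic_token_ids_arr while the loop below converts ready chunks to audio
                # and streams them to the client.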
                semantic_token_ids_arr = []
                llm_is_done_flag = [False]
                llm_thread = threading.Thread(
                    target=self._llm_gen_thread,
                    args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
                )
                llm_thread.start()

                token_offset, chunk_index = 0, 0
                start_time = time.time()
                this_token_hop_len = self.token_hop_len

                while True:
                    pending_num = len(semantic_token_ids_arr) - token_offset
                    if llm_is_done_flag[0]:
                        break
                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
                        sub_tts_speech = self.forward_token2wav(
                            this_tts_speech_token, token2wav_request_id, prompt_speech_tokens,
                            prompt_speech_feat, prompt_spk_embedding, token_offset, False
                        )
                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                        response_sender.send(inference_response)

                        token_offset += this_token_hop_len
                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
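
                        # Choose the next chunk size according to the configured strategy.
                        # "exponential": after the first chunk, the hop length grows as 25, 50,
                        # 100, ... tokens, keeping first-chunk latency low while amortizing
                        # later token2wav calls over larger chunks.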
                        if self.dynamic_chunk_strategy == "exponential":
                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
                        elif self.dynamic_chunk_strategy == "time_based":
                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
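                            # "multiples" estimates the real-time headroom: audio duration already
                            # generated minus elapsed wall-clock time, measured in average
                            # chunk-processing intervals; with more headroom, larger chunks
                            # (multiples of token_hop_len) are flushed at once.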
                            cost_time = time.time() - start_time
                            duration = token_offset / self.token_frame_rate
                            if chunk_index > 0 and cost_time > 0:
                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
                                if avg_chunk_processing_time > 0:
                                    multiples = (duration - cost_time) / avg_chunk_processing_time
                                    self.logger.log_info(f"multiples: {multiples}")
                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
                                    if multiples > 4:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
                                    elif multiples > 2:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
                                    else:
                                        this_token_hop_len = self.token_hop_len
                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
                        chunk_index += 1
                    else:
                        time.sleep(0.02)
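
                # Once the LLM is done, synthesize the remaining tokens in one final chunk
                # with finalize=True, then close the response stream.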
                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                response_sender.send(inference_response)

                llm_thread.join()
                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                self.logger.log_info("Sent TRITONSERVER_RESPONSE_COMPLETE_FINAL to end the response stream")
            else:
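                # Offline (non-decoupled) mode: take the full token sequence at once and
                # synthesize the complete waveform in a single token2wav call.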
                generated_ids = next(generated_ids_iter)
                if generated_ids is None or len(generated_ids) == 0:
                    raise pb_utils.TritonModelException("Generated IDs is None or empty")
                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)

                audio = self.forward_token2wav(generated_ids, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)

                # Prepare response
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                responses.append(inference_response)

        if not self.decoupled:
            return responses