model.py
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import threading
import time
from typing import Dict, List, Tuple, Optional, Union

import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio

from matcha.utils.audio import mel_spectrogram

torch.set_num_threads(1)


class TritonPythonModel:
    """Triton Python model for Spark TTS.

    This model orchestrates the end-to-end TTS pipeline by coordinating
    between audio tokenizer, LLM, and vocoder components.
    """
    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        self.logger = pb_utils.Logger

        # Parse model parameters
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.logger.log_info(f"model_params:{model_params}")
        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

        # Initialize tokenizer
        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")

        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)

        self.token_frame_rate = 25
        self.flow_pre_lookahead_len = 3
        self.token_hop_len = 15
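    # Streaming-parameter note (illustrative arithmetic, derived from the constants
    # set in initialize): at token_frame_rate = 25 semantic tokens per second,
    # token_hop_len = 15 tokens corresponds to roughly 0.6 s of audio per streamed
    # chunk, and flow_pre_lookahead_len = 3 tokens (~0.12 s) of extra right-context
    # are handed to the token2wav model at each chunk boundary.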
    def forward_llm(self, input_ids):
        """Generate semantic tokens with the TensorRT-LLM model.

        Builds a `pb_utils.InferenceRequest` from `input_ids` and sends it to the
        `tensorrt_llm` model. For each response from the language model it:
        - checks for errors and raises an exception if any are found,
        - extracts the "output_ids" and "sequence_length" tensors,
        - yields the generated token IDs truncated to the reported sequence length.

        Parameters
        ----------
        input_ids : torch.Tensor
            Prompt token IDs with shape [1, sequence_length].

        Yields
        ------
        numpy.ndarray
            Generated token IDs; one array per streamed response in decoupled
            mode, a single array otherwise.
        """
        # Convert input_ids to numpy, with shape [1, sequence_length]
        input_ids = input_ids.cpu().numpy()
        max_tokens = 750
        input_dict = {
            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
            "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
            "pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
            "streaming": np.array([[self.decoupled]], dtype=np.bool_),
            "runtime_top_p": np.array([[0.95]], dtype=np.float32),
            "runtime_top_k": np.array([[50]], dtype=np.int32),
            "temperature": np.array([[0.8]], dtype=np.float32),
            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
            "random_seed": np.array([[42]], dtype=np.uint64),
            "input_ids": input_ids,
            "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
        }

        # Convert inputs to Triton tensors
        input_tensor_list = [pb_utils.Tensor(k, v) for k, v in input_dict.items()]

        # Create and execute inference request
        llm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            requested_output_names=["output_ids", "sequence_length"],
            inputs=input_tensor_list,
        )
        llm_responses = llm_request.exec(decoupled=self.decoupled)
        if self.decoupled:
            for llm_response in llm_responses:
                if llm_response.has_error():
                    raise pb_utils.TritonModelException(llm_response.error().message())

                # Extract and process output
                output_ids = pb_utils.get_output_tensor_by_name(
                    llm_response, "output_ids").as_numpy()
                seq_lens = pb_utils.get_output_tensor_by_name(
                    llm_response, "sequence_length").as_numpy()

                # Get actual output IDs up to the sequence length
                actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
                yield actual_output_ids
        else:
            llm_response = llm_responses
            if llm_response.has_error():
                raise pb_utils.TritonModelException(llm_response.error().message())

            # Extract and process output
            output_ids = pb_utils.get_output_tensor_by_name(
                llm_response, "output_ids").as_numpy()
            seq_lens = pb_utils.get_output_tensor_by_name(
                llm_response, "sequence_length").as_numpy()

            # Get actual output IDs up to the sequence length
            actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
            yield actual_output_ids
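    # Usage sketch (illustrative only): forward_llm is a generator in both modes. In
    # decoupled (streaming) mode it yields one partial token array per TensorRT-LLM
    # response; in non-decoupled mode it yields exactly once with the full sequence,
    # so callers can do, e.g.:
    #
    #     ids = next(self.forward_llm(input_ids))          # non-decoupled
    #     for ids in self.forward_llm(input_ids): ...      # decoupled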
    def forward_audio_tokenizer(self, wav, wav_len):
        """Forward pass through the audio tokenizer component.

        Args:
            wav: Input waveform tensor
            wav_len: Waveform length tensor

        Returns:
            Prompt speech token tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['prompt_speech_tokens'],
            inputs=[wav, wav_len]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
        prompt_speech_tokens = from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
        return prompt_speech_tokens
    def forward_speaker_embedding(self, wav):
        """Forward pass through the speaker embedding component.

        Args:
            wav: Input waveform tensor

        Returns:
            Prompt speaker embedding tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='speaker_embedding',
            requested_output_names=['prompt_spk_embedding'],
            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
        prompt_spk_embedding = from_dlpack(prompt_spk_embedding.to_dlpack())
        return prompt_spk_embedding
    def forward_token2wav(
            self,
            prompt_speech_tokens: torch.Tensor,
            prompt_speech_feat: torch.Tensor,
            prompt_spk_embedding: torch.Tensor,
            target_speech_tokens: torch.Tensor,
            request_id: str,
            token_offset: int = None,
            finalize: bool = None) -> torch.Tensor:
        """Forward pass through the vocoder component.

        Args:
            prompt_speech_tokens: Prompt speech tokens tensor
            prompt_speech_feat: Prompt speech feature tensor
            prompt_spk_embedding: Prompt speaker embedding tensor
            target_speech_tokens: Target speech tokens tensor
            request_id: ID of the originating Triton request, forwarded to token2wav
            token_offset: Start offset of the current chunk when streaming (optional)
            finalize: Whether this is the last chunk of the request (optional)

        Returns:
            Generated waveform tensor
        """
        prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
        prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
        prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))

        inputs_tensor = [prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
        if token_offset is not None:
            assert finalize is not None
            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
            inputs_tensor.append(token_offset_tensor)
            inputs_tensor.append(finalize_tensor)

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav',
            requested_output_names=['waveform'],
            inputs=inputs_tensor,
            request_id=request_id,
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output waveform
        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = from_dlpack(waveform.to_dlpack()).cpu()
        return waveform
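    # Streaming semantics of forward_token2wav (as used in execute below): each call
    # passes the full token prefix together with token_offset so the vocoder knows
    # where the new chunk starts; finalize=False produces an intermediate chunk and
    # finalize=True produces the remaining audio on the last call. The token2wav
    # model presumably keeps per-request streaming state keyed by request_id (an
    # assumption based on the request_id/token_offset/finalize inputs; its config is
    # not shown here). When token_offset/finalize are omitted, the whole utterance
    # is vocoded in one shot.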
    def parse_input(self, text, prompt_text, prompt_speech_tokens):
        """Build the LLM prompt: reference text + target text, followed by the prompt speech tokens."""
        total_text = f"{prompt_text}{text}"
        prompt = self.prompt_template.format(input_text=total_text)
        input_ids = self.tokenizer.encode(prompt)
        input_ids = torch.tensor([input_ids], dtype=torch.int32)
        input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
        return input_ids
    def _extract_speech_feat(self, speech):
        speech_feat = mel_spectrogram(
            speech,
            n_fft=1920,
            num_mels=80,
            sampling_rate=24000,
            hop_size=480,
            win_size=1920,
            fmin=0,
            fmax=8000,
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        return speech_feat
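    # Feature-rate arithmetic behind _extract_speech_feat (illustrative): with
    # sampling_rate=24000 and hop_size=480 the mel features come out at
    # 24000 / 480 = 50 frames per second, i.e. exactly two mel frames per 25 Hz
    # speech token. That 2:1 ratio is why execute() trims the prompt with
    # token_len = min(feat_frames // 2, num_prompt_tokens) and keeps 2 * token_len
    # feature frames.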
    def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
        # Drain the LLM generator in a background thread, appending tokens to the
        # shared list and raising the done flag once generation finishes.
        for generated_ids in generated_ids_iter:
            generated_ids = generated_ids.tolist()
            if len(generated_ids) == 0:
                break
            semantic_token_ids_arr.extend(generated_ids)
        llm_is_done_flag[0] = True
    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing generated audio
        """
        responses = []

        for request in requests:
            request_id = request.request_id()
            # Extract input tensors
            wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")

            # Process reference audio through audio tokenizer
            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)

            wav_tensor = wav.as_numpy()
            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
            speech_feat = self._extract_speech_feat(prompt_speech_resample)

            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
            flow_prompt_speech_token_len = prompt_speech_tokens.shape[-1]

            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode('utf-8')
            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
            target_text = target_text[0][0].decode('utf-8')

            # Prepare prompt for LLM
            input_ids = self.parse_input(
                text=target_text,
                prompt_text=reference_text,
                prompt_speech_tokens=prompt_speech_tokens,
            )
            # Generate semantic tokens with LLM
            generated_ids_iter = self.forward_llm(input_ids)
            prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)

            if self.decoupled:
                response_sender = request.get_response_sender()
                semantic_token_ids_arr = []
                llm_is_done_flag = [False]
                llm_thread = threading.Thread(
                    target=self._llm_gen_thread,
                    args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
                )
                llm_thread.start()

                token_offset, chunk_index = 0, 0
                start_time = time.time()
                this_token_hop_len = self.token_hop_len
                while True:
                    pending_num = len(semantic_token_ids_arr) - token_offset
                    if llm_is_done_flag[0]:
                        break
                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
                        sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, False)
                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                        response_sender.send(inference_response)

                        token_offset += this_token_hop_len
                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")

                        if self.dynamic_chunk_strategy == "exponential":
                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
                        elif self.dynamic_chunk_strategy == "time_based":
                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
                            cost_time = time.time() - start_time
                            duration = token_offset / self.token_frame_rate
                            if chunk_index > 0 and cost_time > 0:
                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
                                if avg_chunk_processing_time > 0:
                                    multiples = (duration - cost_time) / avg_chunk_processing_time
                                    self.logger.log_info(f"multiples: {multiples}")
                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
                                    if multiples > 4:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
                                    elif multiples > 2:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
                                    else:
                                        this_token_hop_len = self.token_hop_len
                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
                        chunk_index += 1
                    else:
                        time.sleep(0.02)
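                # Chunk-scheduling sketch (illustrative arithmetic based on the
                # constants set in initialize): the "exponential" strategy grows the
                # hop as token_frame_rate * 2**chunk_index, i.e. 15 (initial), 25,
                # 50, 100, ... tokens, trading higher latency on later chunks for
                # fewer vocoder calls. The "time_based" strategy compares the audio
                # duration generated so far (token_offset / 25 seconds) with the
                # wall-clock synthesis time and enlarges the hop (in multiples of
                # token_hop_len) only while synthesis runs comfortably ahead of real
                # time. Once the LLM thread reports completion, the code below
                # flushes the full token sequence with finalize=True.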
                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
                sub_tts_speech = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, this_tts_speech_token, request_id, token_offset, True)
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                response_sender.send(inference_response)

                llm_thread.join()
                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                self.logger.log_info("send tritonserver_response_complete_final to end")
            else:
                generated_ids = next(generated_ids_iter)
                if generated_ids is None or len(generated_ids) == 0:
                    raise pb_utils.TritonModelException("Generated IDs is None or empty")
                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)

                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids, request_id)

                # Prepare response
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                responses.append(inference_response)
        if not self.decoupled:
            return responses
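# ---------------------------------------------------------------------------
# Illustrative client sketch (not part of this model; the server URL, BLS model
# name, and FP32 waveform dtype below are assumptions, and it targets a
# non-decoupled deployment, since a decoupled one needs a streaming client).
# The tensor names mirror the inputs read in execute() above:
#
#     import numpy as np
#     import soundfile as sf
#     import tritonclient.grpc as grpcclient
#
#     wav, sr = sf.read("prompt.wav", dtype="float32")     # expected at 16 kHz
#     wav = wav[np.newaxis, :]                              # shape [1, num_samples]
#     wav_len = np.array([[wav.shape[1]]], dtype=np.int32)
#     ref_text = np.array([["transcript of the prompt audio".encode("utf-8")]], dtype=object)
#     tgt_text = np.array([["text to synthesize".encode("utf-8")]], dtype=object)
#
#     client = grpcclient.InferenceServerClient("localhost:8001")
#     inputs = [
#         grpcclient.InferInput("reference_wav", list(wav.shape), "FP32"),
#         grpcclient.InferInput("reference_wav_len", list(wav_len.shape), "INT32"),
#         grpcclient.InferInput("reference_text", list(ref_text.shape), "BYTES"),
#         grpcclient.InferInput("target_text", list(tgt_text.shape), "BYTES"),
#     ]
#     for inp, data in zip(inputs, [wav, wav_len, ref_text, tgt_text]):
#         inp.set_data_from_numpy(data)
#     result = client.infer("tts_bls_model", inputs)        # model name is a guess
#     waveform = result.as_numpy("waveform")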