model.py

# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import threading
import time
from typing import Dict, List, Tuple, Optional, Union

import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio
from matcha.utils.audio import mel_spectrogram
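
# Size of the base text vocabulary; raw speech-token indices are shifted by this
# amount so they land in the LLM's extended (text + speech) token ID range
# (see the pre-cached speaker branch in execute()).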
ORIGINAL_VOCAB_SIZE = 151663

torch.set_num_threads(1)


class TritonPythonModel:
    """Triton Python model for Spark TTS.

    This model orchestrates the end-to-end TTS pipeline by coordinating
    between audio tokenizer, LLM, and vocoder components.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        self.logger = pb_utils.Logger

        # Parse model parameters
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.logger.log_info(f"model_params:{model_params}")
        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

        # Initialize tokenizer
        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")

        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
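
        # Streaming/chunking parameters: the LLM emits speech tokens at
        # `token_frame_rate` tokens per second of audio; each streamed chunk
        # covers `token_hop_len` tokens, with `flow_pre_lookahead_len` extra
        # tokens buffered as lookahead before a chunk is vocoded.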
        self.token_frame_rate = 25
        self.flow_pre_lookahead_len = 3
        self.token_hop_len = 15

        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
        if not os.path.exists(spk_info_path):
            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
        self.default_spk_info = spk_info["001"]

    def forward_llm(self, input_ids):
        """
        Prepares the request for the language model based on the provided
        input IDs. Creates a `pb_utils.InferenceRequest` object and sends it
        to the (optionally decoupled) TensorRT-LLM model.

        For each response from the language model:
        - Checks for errors and raises an exception if any are found.
        - Extracts the "output_ids" and "sequence_length" tensors from the
          response.
        - Truncates the output IDs to the reported sequence length.
        - Yields the generated token IDs.

        Parameters
        ----------
        - input_ids (torch.Tensor): Prompt token IDs with shape [1, sequence_length].

        Yields
        ------
        - numpy.ndarray: The generated token IDs for each response.
        """
        # convert input_ids to numpy, with shape [1, sequence_length]
        input_ids = input_ids.cpu().numpy()
        max_tokens = 750
        input_dict = {
            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
            "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
            "pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
            "streaming": np.array([[self.decoupled]], dtype=np.bool_),
            "runtime_top_p": np.array([[0.95]], dtype=np.float32),
            "runtime_top_k": np.array([[50]], dtype=np.int32),
            "temperature": np.array([[0.8]], dtype=np.float32),
            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
            "random_seed": np.array([[42]], dtype=np.uint64),
            "input_ids": input_ids,
            "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
        }

        # Convert inputs to Triton tensors
        input_tensor_list = [
            pb_utils.Tensor(k, v) for k, v in input_dict.items()
        ]

        # Create and execute inference request
        llm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            requested_output_names=["output_ids", "sequence_length"],
            inputs=input_tensor_list,
        )
        llm_responses = llm_request.exec(decoupled=self.decoupled)
        if self.decoupled:
            for llm_response in llm_responses:
                if llm_response.has_error():
                    raise pb_utils.TritonModelException(llm_response.error().message())

                # Extract and process output
                output_ids = pb_utils.get_output_tensor_by_name(
                    llm_response, "output_ids").as_numpy()
                seq_lens = pb_utils.get_output_tensor_by_name(
                    llm_response, "sequence_length").as_numpy()

                # Get actual output IDs up to the sequence length
                actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
                yield actual_output_ids
        else:
            llm_response = llm_responses
            if llm_response.has_error():
                raise pb_utils.TritonModelException(llm_response.error().message())

            # Extract and process output
            output_ids = pb_utils.get_output_tensor_by_name(
                llm_response, "output_ids").as_numpy()
            seq_lens = pb_utils.get_output_tensor_by_name(
                llm_response, "sequence_length").as_numpy()

            # Get actual output IDs up to the sequence length
            actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
            yield actual_output_ids

    def forward_audio_tokenizer(self, wav, wav_len):
        """Forward pass through the audio tokenizer component.

        Args:
            wav: Input waveform tensor
            wav_len: Waveform length tensor

        Returns:
            Prompt speech tokens tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['prompt_speech_tokens'],
            inputs=[wav, wav_len]
        )

        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
        prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()

        return prompt_speech_tokens

    def forward_speaker_embedding(self, wav):
        """Forward pass through the speaker embedding component.

        Args:
            wav: Input waveform tensor

        Returns:
            Prompt speaker embedding tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='speaker_embedding',
            requested_output_names=['prompt_spk_embedding'],
            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
        )

        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())

        return prompt_spk_embedding

    def forward_token2wav(
            self,
            target_speech_tokens: torch.Tensor,
            request_id: str,
            prompt_speech_tokens: torch.Tensor = None,
            prompt_speech_feat: torch.Tensor = None,
            prompt_spk_embedding: torch.Tensor = None,
            token_offset: int = None,
            finalize: bool = None) -> torch.Tensor:
        """Forward pass through the vocoder component.

        Args:
            target_speech_tokens: Target speech tokens tensor
            request_id: Triton request ID, forwarded to the token2wav model
            prompt_speech_tokens: Prompt speech tokens tensor (optional)
            prompt_speech_feat: Prompt speech feature tensor (optional)
            prompt_spk_embedding: Prompt speaker embedding tensor (optional)
            token_offset: Number of tokens already synthesized for this request
                (streaming mode only)
            finalize: Whether this is the last chunk of the request
                (streaming mode only)

        Returns:
            Generated waveform tensor
        """
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
        inputs_tensor = [target_speech_tokens_tensor]

        if token_offset is not None:
            assert finalize is not None
            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
            inputs_tensor.append(token_offset_tensor)
            inputs_tensor.append(finalize_tensor)

        if prompt_spk_embedding is not None:
            assert prompt_speech_feat is not None
            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav',
            requested_output_names=['waveform'],
            inputs=inputs_tensor,
            request_id=request_id,
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output waveform
        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()

        return waveform

    def parse_input(self, text, prompt_text, prompt_speech_tokens):
        total_text = f"{prompt_text}{text}"
        prompt = self.prompt_template.format(input_text=total_text)
        input_ids = self.tokenizer.encode(prompt)
        input_ids = torch.tensor([input_ids], dtype=torch.int32)
        input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
        return input_ids
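
    # 80-bin mel spectrogram at 24 kHz with a 480-sample hop (50 frames per
    # second), i.e. two mel frames per 25 Hz speech token; used as the prompt
    # feature for token2wav.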
    def _extract_speech_feat(self, speech):
        speech_feat = mel_spectrogram(
            speech,
            n_fft=1920,
            num_mels=80,
            sampling_rate=24000,
            hop_size=480,
            win_size=1920,
            fmin=0,
            fmax=8000,
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        return speech_feat
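
    # Producer side of the streaming pipeline: drains the LLM token iterator
    # into a shared list that execute() consumes, and raises a flag (a
    # one-element list shared with the main thread) once generation is done.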
    def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
        for generated_ids in generated_ids_iter:
            generated_ids = generated_ids.tolist()
            if len(generated_ids) == 0:
                break
            semantic_token_ids_arr.extend(generated_ids)
        llm_is_done_flag[0] = True

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            List of inference responses containing generated audio
        """
        responses = []

        for request in requests:
            request_id = request.request_id()
            # Extract input tensors
            wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")

            # Process reference audio through audio tokenizer
            if wav is not None:
                wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")

                prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
                prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)

                wav_tensor = wav.as_numpy()
                wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
                prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
                speech_feat = self._extract_speech_feat(prompt_speech_resample)
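
                # Keep mel features and speech tokens aligned: there are two mel
                # frames (50 Hz) per speech token (25 Hz), so truncate both to
                # the same effective token length.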
                token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
                prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
                prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

                reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
                reference_text = reference_text[0][0].decode('utf-8')

                prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
            else:
                # using pre-cached reference text
                reference_text = self.default_spk_info["prompt_text"]
                prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
                prompt_speech_feat = None
                prompt_spk_embedding = None

            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
            target_text = target_text[0][0].decode('utf-8')

            # Prepare prompt for LLM
            input_ids = self.parse_input(
                text=target_text,
                prompt_text=reference_text,
                prompt_speech_tokens=prompt_speech_tokens,
            )

            # Generate semantic tokens with LLM
            generated_ids_iter = self.forward_llm(input_ids)

            if self.decoupled:
                response_sender = request.get_response_sender()

                semantic_token_ids_arr = []
                llm_is_done_flag = [False]
                llm_thread = threading.Thread(
                    target=self._llm_gen_thread,
                    args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
                )
                llm_thread.start()

                token_offset, chunk_index = 0, 0
                start_time = time.time()
                this_token_hop_len = self.token_hop_len
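
                # Consumer loop: as soon as enough tokens are buffered (current
                # hop length plus lookahead), vocode them into a chunk of audio
                # and stream it back; otherwise wait for the LLM thread to
                # produce more tokens.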
                while True:
                    pending_num = len(semantic_token_ids_arr) - token_offset
                    if llm_is_done_flag[0]:
                        break
                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
                        sub_tts_speech = self.forward_token2wav(
                            this_tts_speech_token, request_id, prompt_speech_tokens,
                            prompt_speech_feat, prompt_spk_embedding, token_offset, False
                        )

                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                        response_sender.send(inference_response)

                        token_offset += this_token_hop_len
                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
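
                        # Pick the hop length for the next chunk:
                        # - "exponential" doubles the chunk size with each chunk
                        #   (token_frame_rate * 2**chunk_index tokens).
                        # - "time_based" compares generated audio duration with
                        #   elapsed wall-clock time and grows the chunk only when
                        #   synthesis is running well ahead of playback.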
                        if self.dynamic_chunk_strategy == "exponential":
                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
                        elif self.dynamic_chunk_strategy == "time_based":
                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
                            cost_time = time.time() - start_time
                            duration = token_offset / self.token_frame_rate
                            if chunk_index > 0 and cost_time > 0:
                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
                                if avg_chunk_processing_time > 0:
                                    multiples = (duration - cost_time) / avg_chunk_processing_time
                                    self.logger.log_info(f"multiples: {multiples}")
                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
                                    if multiples > 4:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
                                    elif multiples > 2:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
                                    else:
                                        this_token_hop_len = self.token_hop_len
                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
                        chunk_index += 1
                    else:
                        time.sleep(0.02)
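
                # After the LLM thread finishes, flush all remaining tokens as
                # the last chunk (finalize=True) and close the response stream.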
                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                response_sender.send(inference_response)

                llm_thread.join()
                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                self.logger.log_info("send tritonserver_response_complete_final to end")
            else:
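                # Offline (non-decoupled) mode: wait for the full token sequence,
                # then synthesize the waveform in a single token2wav call.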
                generated_ids = next(generated_ids_iter)
                if generated_ids is None or len(generated_ids) == 0:
                    raise pb_utils.TritonModelException("Generated IDs is None or empty")
                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)

                audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)

                # Prepare response
                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                responses.append(inference_response)

        if not self.decoupled:
            return responses