# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio
import json
import re
import time
from typing import List, Union

import httpx
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer

from matcha.utils.audio import mel_spectrogram

# Size of the base LLM vocabulary; speech token <|s_N|> maps to ID N + ORIGINAL_VOCAB_SIZE.
ORIGINAL_VOCAB_SIZE = 151663

torch.set_num_threads(1)


def parse_speech_token_string(response_text: str) -> List[int]:
    """
    Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
    """
    speech_tokens = response_text.strip().split('><')
    if len(speech_tokens) > 1:
        # Add back the '<' and '>' stripped by the split so each token parses cleanly
        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
    speech_ids = []
    for token_str in speech_tokens:
        match = re.match(r'<\|s_(\d+)\|>', token_str)
        if match:
            speech_ids.append(int(match.group(1)))
    return speech_ids


class TritonPythonModel:
    """Triton Python model for CosyVoice DIT TTS.

    This model orchestrates the end-to-end TTS pipeline by coordinating
    between the audio tokenizer, LLM, and vocoder components.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        self.logger = pb_utils.Logger

        # Parse model parameters
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential", "equal", or "time_based"
        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

        # Initialize tokenizer
        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")

        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
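
        # Streaming constants: the LLM emits speech tokens at 25 Hz, so a
        # 15-token hop yields ~0.6 s of audio per chunk; the extra 3 tokens are
        # lookahead context for the flow/DiT decoder at each chunk boundary.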
        self.token_frame_rate = 25
        self.flow_pre_lookahead_len = 3
        self.token_hop_len = 15

        # Async HTTP client for the local OpenAI-compatible trtllm-serve endpoint
        self.http_client = httpx.AsyncClient()
        self.api_base = "http://localhost:8000/v1/chat/completions"
        self.speaker_cache = {}

    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
        """Converts a tensor or list of speech token IDs to a string representation."""
        if isinstance(speech_tokens, torch.Tensor):
            # Ensure tensor is on CPU and flattened
            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
        speech_id_str = ""
        for token_id in speech_tokens:
            # Convert token ID back to the speech number N, e.g. 151705 -> "<|s_42|>"
            token_num = token_id - ORIGINAL_VOCAB_SIZE
            speech_id_str += f"<|s_{token_num}|>"
        return speech_id_str

    async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
        """
        Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
        """
        full_text = f"{reference_text}{target_text}"
        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
        chat = [
            {"role": "user", "content": full_text},
            {"role": "assistant", "content": prompt_speech_tokens_str}
        ]

        payload = {
            "model": "trt_engines_bfloat16",
            "messages": chat,
            "max_tokens": 750,
            "temperature": 0.8,
            "top_p": 0.95,
            "top_k": 50,
            "repetition_penalty": 1.1,
            "stop": ["<|eos1|>", "<|eos|>"],
            "stream": True,
        }
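        # Consume the SSE stream line by line, accumulating delta text in a buffer
        # and yielding each complete <|s_N|> token as its global vocabulary ID.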
        buffer = ""
        async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    line_data = line[len("data: "):].strip()
                    if line_data == "[DONE]":
                        break
                    try:
                        json_data = json.loads(line_data)
                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
                        if content:
                            buffer += content
                            while True:
                                match = re.search(r"<\|s_(\d+)\|>", buffer)
                                if not match:
                                    break
                                token_num = int(match.group(1))
                                final_id = token_num + ORIGINAL_VOCAB_SIZE
                                yield final_id
                                buffer = buffer[match.end():]
                    except json.JSONDecodeError:
                        self.logger.log_info(f"Skipping non-JSON line: {line_data}")
                        continue

        # Process any remaining complete tokens in the buffer after the stream ends
        while True:
            match = re.search(r"<\|s_(\d+)\|>", buffer)
            if not match:
                break
            token_num = int(match.group(1))
            final_id = token_num + ORIGINAL_VOCAB_SIZE
            yield final_id
            buffer = buffer[match.end():]

    def forward_audio_tokenizer(self, wav, wav_len):
        """Forward pass through the audio tokenizer component.

        Args:
            wav: Input waveform tensor
            wav_len: Waveform length tensor

        Returns:
            Prompt speech tokens tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['prompt_speech_tokens'],
            inputs=[wav, wav_len]
        )

        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
        prompt_speech_tokens = from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
        return prompt_speech_tokens

    def forward_speaker_embedding(self, wav):
        """Forward pass through the speaker embedding component.

        Args:
            wav: Input waveform tensor

        Returns:
            Prompt speaker embedding tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='speaker_embedding',
            requested_output_names=['prompt_spk_embedding'],
            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
        )

        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
        prompt_spk_embedding = from_dlpack(prompt_spk_embedding.to_dlpack())
        return prompt_spk_embedding

    async def forward_token2wav(
            self,
            index: int,
            target_speech_tokens: torch.Tensor,
            request_id: str,
            reference_wav: object,
            reference_wav_len: object,
            finalize: bool = False) -> torch.Tensor:
        """Forward pass through the vocoder component.

        Args:
            index: Index of the chunk within the request
            target_speech_tokens: Target speech tokens tensor
            request_id: Request ID
            reference_wav: Reference waveform tensor
            reference_wav_len: Reference waveform length tensor
            finalize: Whether this is the final chunk of the request

        Returns:
            Generated waveform tensor
        """
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav_dit',
            requested_output_names=["waveform"],
            inputs=inputs_tensor,
            request_id=request_id,
            parameters={"priority": index + 1},
        )
        inference_response = await inference_request.async_exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output waveform
        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = from_dlpack(waveform.to_dlpack()).cpu()
        return waveform

    def _extract_speech_feat(self, speech):
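        # 80-bin mel spectrogram at 24 kHz with a 480-sample hop (20 ms) and a
        # 1920-sample window/FFT (80 ms); presumably matched to the features
        # expected by the token2wav front end.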
        speech_feat = mel_spectrogram(
            speech,
            n_fft=1920,
            num_mels=80,
            sampling_rate=24000,
            hop_size=480,
            win_size=1920,
            fmin=0,
            fmax=8000,
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        return speech_feat

    async def _process_request(self, request):
        request_id = request.request_id()
        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
        reference_text = reference_text[0][0].decode('utf-8')
        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
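        # Cache the reference audio's speech tokens keyed by the reference transcript,
        # so repeated requests with the same reference skip the audio tokenizer.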
        if reference_text not in self.speaker_cache:
            self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
        prompt_speech_tokens = self.speaker_cache[reference_text]

        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
        target_text = target_text[0][0].decode('utf-8')

        if self.decoupled:
            response_sender = request.get_response_sender()
            semantic_token_ids_arr = []
            token_offset, chunk_index = 0, 0
            start_time = time.time()
            this_token_hop_len = self.token_hop_len
            async for generated_ids in self.forward_llm_async(
                target_text=target_text,
                reference_text=reference_text,
                prompt_speech_tokens=prompt_speech_tokens,
            ):
                if not generated_ids:
                    break
                semantic_token_ids_arr.append(generated_ids)
                while True:
                    pending_num = len(semantic_token_ids_arr) - token_offset
                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
                        sub_tts_speech = await self.forward_token2wav(
                            chunk_index, this_tts_speech_token, request_id, wav, wav_len, False
                        )
                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                        response_sender.send(inference_response)
                        token_offset += this_token_hop_len
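                        # Pick the next chunk size: "exponential" doubles each chunk
                        # (token_frame_rate * 2**i tokens), "equal" keeps the fixed
                        # 15-token hop, and "time_based" sizes the next chunk by how
                        # far the synthesized audio leads wall-clock generation time.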
                        if self.dynamic_chunk_strategy == "exponential":
                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
                        elif self.dynamic_chunk_strategy == "equal":
                            this_token_hop_len = self.token_hop_len
                        elif self.dynamic_chunk_strategy == "time_based":
                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
                            cost_time = time.time() - start_time
                            duration = token_offset / self.token_frame_rate
                            if chunk_index > 0 and cost_time > 0:
                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
                                if avg_chunk_processing_time > 0:
                                    # multiples: how far synthesized audio leads generation,
                                    # measured in average chunk processing times
                                    multiples = (duration - cost_time) / avg_chunk_processing_time
                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
                                    if multiples > 4:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
                                    elif multiples > 2:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
                                    else:
                                        this_token_hop_len = self.token_hop_len
                        # Never fall below the base hop length
                        this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
                        chunk_index += 1
                    else:
                        break

            # Flush any remaining tokens as the final chunk
            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
            response_sender.send(inference_response)
            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
        else:
            raise NotImplementedError("Offline TTS mode is not supported")

    async def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            None; in decoupled mode audio chunks are streamed through each
            request's response sender
        """
        tasks = [
            asyncio.create_task(self._process_request(request))
            for request in requests
        ]
        await asyncio.gather(*tasks)
        return None

    def finalize(self):
        self.logger.log_info("Finalizing CosyVoice DIT model")
        if hasattr(self, "http_client"):
            # Close the shared async HTTP client used for the LLM endpoint
            asyncio.run(self.http_client.aclose())