model.py

# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import time
import asyncio
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Union

import httpx
import numpy as np
import torch
import torchaudio
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
from matcha.utils.audio import mel_spectrogram

ORIGINAL_VOCAB_SIZE = 151663

torch.set_num_threads(1)

def parse_speech_token_string(response_text: str) -> List[int]:
    """
    Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
    """
    speech_tokens = response_text.strip().split('><')
    if len(speech_tokens) > 1:
        # Add back the missing '<' and '>' for proper parsing
        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
    speech_ids = []
    for token_str in speech_tokens:
        match = re.match(r'<\|s_(\d+)\|>', token_str)
        if match:
            speech_ids.append(int(match.group(1)))
    return speech_ids
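
# Illustrative example of the helper above (derived from its regex, not from any
# runtime log):
#     parse_speech_token_string("<|s_12|><|s_345|>")  ->  [12, 345]
# Substrings that do not match the <|s_N|> pattern are silently dropped.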


class TritonPythonModel:
    """Triton Python model for Spark TTS.

    This model orchestrates the end-to-end TTS pipeline by coordinating
    between the audio tokenizer, LLM, and vocoder components.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        self.logger = pb_utils.Logger

        # Parse model parameters
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.logger.log_info(f"model_params:{model_params}")

        # Chunk scheduling strategy: "exponential", "equal", or "time_based"
        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")
        # self.dynamic_chunk_strategy = "equal"
        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

        # Initialize tokenizer
        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")

        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)

        # Streaming chunk parameters (measured in speech tokens)
        self.token_frame_rate = 25
        self.flow_pre_lookahead_len = 3
        self.token_hop_len = 15

        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
        if not os.path.exists(spk_info_path):
            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
        self.default_spk_info = spk_info["001"]

        self.http_client = httpx.AsyncClient()
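
    # Note on the streaming constants set above: generated audio corresponds to
    # token_frame_rate (25) speech tokens per second, the first vocoder chunk is
    # token_hop_len (15) tokens, and each chunk carries flow_pre_lookahead_len (3)
    # extra tokens, presumably to give the token2wav model some right-context at
    # chunk boundaries.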

    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
        """Converts a tensor or list of speech token IDs to a string representation."""
        if isinstance(speech_tokens, torch.Tensor):
            # Ensure tensor is on CPU and flattened
            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
        speech_id_str = ""
        for token_id in speech_tokens:
            # Convert token ID back to the speech number N
            token_num = token_id - ORIGINAL_VOCAB_SIZE
            speech_id_str += f"<|s_{token_num}|>"
        return speech_id_str
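
    # For reference: speech tokens are offset past the original text vocabulary,
    # so LLM token ID (ORIGINAL_VOCAB_SIZE + N) corresponds to the string
    # "<|s_N|>" here, and forward_llm_async applies the inverse offset when it
    # parses the streamed response.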

    async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
        """
        Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
        """
        full_text = f"{reference_text}{target_text}"
        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
        chat = [
            {"role": "user", "content": full_text},
            {"role": "assistant", "content": prompt_speech_tokens_str},
        ]
        payload = {
            "model": "trt_engines_bfloat16",
            "messages": chat,
            "max_tokens": 750,
            "temperature": 0.8,
            "top_p": 0.95,
            "top_k": 50,
            "repetition_penalty": 1.1,
            "stop": ["<|eos1|>", "<|eos|>"],
            "stream": True,
        }
        api_base = "http://localhost:8000/v1/chat/completions"
        buffer = ""
        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
            print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
            print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    line_data = line[len("data: "):].strip()
                    if line_data == "[DONE]":
                        break
                    try:
                        json_data = json.loads(line_data)
                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
                        if content:
                            buffer += content
                            print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
                            while True:
                                match = re.search(r"<\|s_(\d+)\|>", buffer)
                                if not match:
                                    break
                                token_num = int(match.group(1))
                                final_id = token_num + ORIGINAL_VOCAB_SIZE
                                yield final_id
                                buffer = buffer[match.end():]
                    except json.JSONDecodeError:
                        self.logger.log_info(f"Skipping non-JSON line: {line_data}")
                        continue
        # Process any remaining complete tokens in the buffer after the stream ends
        while True:
            match = re.search(r"<\|s_(\d+)\|>", buffer)
            if not match:
                break
            token_num = int(match.group(1))
            final_id = token_num + ORIGINAL_VOCAB_SIZE
            yield final_id
            buffer = buffer[match.end():]
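
    # forward_llm_async is an async generator: it posts an OpenAI-compatible
    # streaming chat request to the local TRTLLM-serve endpoint and yields each
    # complete "<|s_N|>" token as an offset LLM ID (N + ORIGINAL_VOCAB_SIZE) as
    # soon as it arrives, so token2wav synthesis can start before the LLM has
    # finished generating.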

    def forward_audio_tokenizer(self, wav, wav_len):
        """Forward pass through the audio tokenizer component.

        Args:
            wav: Input waveform tensor
            wav_len: Waveform length tensor

        Returns:
            Prompt speech tokens tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['prompt_speech_tokens'],
            inputs=[wav, wav_len]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
        prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
        return prompt_speech_tokens
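
    # Unlike forward_token2wav below, this call uses the blocking exec() BLS API;
    # it runs only once per request, before token streaming starts.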

    def forward_speaker_embedding(self, wav):
        """Forward pass through the speaker embedding component.

        Args:
            wav: Input waveform tensor

        Returns:
            Prompt speaker embedding tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='speaker_embedding',
            requested_output_names=['prompt_spk_embedding'],
            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
        return prompt_spk_embedding
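
    # Note: forward_speaker_embedding is currently unused (the call in
    # _process_request is commented out); it is kept here, presumably for
    # pipeline variants that condition token2wav on an explicit speaker embedding.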

    async def forward_token2wav(
            self,
            index: int,
            target_speech_tokens: torch.Tensor,
            request_id: str,
            reference_wav: object,
            reference_wav_len: object,
            finalize: bool = False) -> torch.Tensor:
        """Forward pass through the vocoder component.

        Args:
            index: Chunk index, forwarded as the request priority
            target_speech_tokens: Target speech tokens tensor
            request_id: Triton request ID used to correlate streaming chunks
            reference_wav: Reference waveform tensor
            reference_wav_len: Reference waveform length tensor
            finalize: Whether this is the last chunk of the utterance

        Returns:
            Generated waveform tensor
        """
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav_dit',
            requested_output_names=[
                "waveform",
            ],
            inputs=inputs_tensor,
            request_id=request_id,
            parameters={"priority": index + 1},
        )
        inference_response = await inference_request.async_exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output waveform
        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
        return waveform
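
    # The chunk index is forwarded as the Triton request priority (index + 1).
    # Assuming the token2wav model enables priority levels in its scheduler,
    # lower values are served first, so earlier chunks of an utterance are
    # decoded ahead of later ones.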

    def _extract_speech_feat(self, speech):
        speech_feat = mel_spectrogram(
            speech,
            n_fft=1920,
            num_mels=80,
            sampling_rate=24000,
            hop_size=480,
            win_size=1920,
            fmin=0,
            fmax=8000,
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        return speech_feat
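
    # With a 24 kHz sampling rate and a hop size of 480, the mel extractor yields
    # 50 frames per second, i.e. two mel frames per 25 Hz speech token, which is
    # why _process_request pairs token_len prompt tokens with 2 * token_len mel
    # frames when trimming the prompt features.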

    async def _process_request(self, request):
        request_id = request.request_id()

        # Extract input tensors
        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")

        # Process reference audio through audio tokenizer
        if wav is not None:
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
            wav_tensor = wav.as_numpy()
            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
            print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
            speech_feat = self._extract_speech_feat(prompt_speech_resample)
            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode('utf-8')
            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
            # reference_text = self.default_spk_info["prompt_text"]
            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
            # prompt_speech_feat = None
            # prompt_spk_embedding = None
        else:
            # Using a pre-cached reference speaker is not supported yet.
            assert False, "using pre-cached reference text is not supported"
            reference_text = self.default_spk_info["prompt_text"]
            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
            prompt_speech_feat = None
            prompt_spk_embedding = None

        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
        target_text = target_text[0][0].decode('utf-8')
        print(f"target_text: {target_text}, time: {datetime.now()}")

        if self.decoupled:
            response_sender = request.get_response_sender()
            semantic_token_ids_arr = []
            token_offset, chunk_index = 0, 0
            start_time = time.time()
            this_token_hop_len = self.token_hop_len
            print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
            async for generated_ids in self.forward_llm_async(
                target_text=target_text,
                reference_text=reference_text,
                prompt_speech_tokens=prompt_speech_tokens,
            ):
                if not generated_ids:
                    break
                semantic_token_ids_arr.append(generated_ids)
                print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
                while True:
                    pending_num = len(semantic_token_ids_arr) - token_offset
                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
                        print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
                        sub_tts_speech = await self.forward_token2wav(
                            chunk_index,
                            this_tts_speech_token, request_id, wav, wav_len, False
                        )
                        print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                        response_sender.send(inference_response)
                        token_offset += this_token_hop_len
                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
                        if self.dynamic_chunk_strategy == "exponential":
                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
                        elif self.dynamic_chunk_strategy == "equal":
                            this_token_hop_len = self.token_hop_len
                        elif self.dynamic_chunk_strategy == "time_based":
                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
                            cost_time = time.time() - start_time
                            duration = token_offset / self.token_frame_rate
                            if chunk_index > 0 and cost_time > 0:
                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
                                if avg_chunk_processing_time > 0:
                                    multiples = (duration - cost_time) / avg_chunk_processing_time
                                    self.logger.log_info(f"multiples: {multiples}")
                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
                                    if multiples > 4:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
                                    elif multiples > 2:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
                                    else:
                                        this_token_hop_len = self.token_hop_len
                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
                        chunk_index += 1
                    else:
                        break

            # Flush whatever tokens remain as the final chunk.
            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
            response_sender.send(inference_response)
            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            self.logger.log_info("send tritonserver_response_complete_final to end")
        else:
            raise NotImplementedError("Non-decoupled mode is not supported")

    async def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            None; responses are streamed through each request's response sender
        """
        tasks = [
            asyncio.create_task(self._process_request(request))
            for request in requests
        ]
        await asyncio.gather(*tasks)
        return None
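
    # execute() returns None because, in decoupled mode, audio is streamed back
    # through each request's response sender; per-request work is fanned out with
    # asyncio so concurrent requests overlap their LLM and token2wav calls.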

    def finalize(self):
        self.logger.log_info("Finalizing CosyVoice DIT model")
        if hasattr(self, "http_client"):
            asyncio.run(self.http_client.aclose())
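
# Deployment note (illustrative; the paths below are placeholders, not taken
# from any real config): the keys read from model_params in initialize() are
# expected to come from the parameters block of this model's config.pbtxt,
# along the lines of:
#
#   parameters [
#     { key: "llm_tokenizer_dir",      value: { string_value: "/workspace/llm_tokenizer" } },
#     { key: "model_dir",              value: { string_value: "/workspace/model_dir" } },
#     { key: "dynamic_chunk_strategy", value: { string_value: "exponential" } }
#   ]
#
# Only these three keys are read by this file.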