# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import math
import os
import re
import time
import asyncio
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Union

import httpx
import numpy as np
import torch
import torchaudio
from torch.utils.dlpack import from_dlpack, to_dlpack
from transformers import AutoTokenizer
from matcha.utils.audio import mel_spectrogram

import triton_python_backend_utils as pb_utils

ORIGINAL_VOCAB_SIZE = 151663

torch.set_num_threads(1)
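
# Note on the speech-token convention used in this module: the LLM exchanges codec
# tokens as strings of the form "<|s_N|>", and token N is represented within this
# module as the vocabulary ID N + ORIGINAL_VOCAB_SIZE (e.g. "<|s_0|>" corresponds
# to ID 151663). The helpers below convert between the two representations.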


def parse_speech_token_string(response_text: str) -> List[int]:
    """
    Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
    """
    speech_tokens = response_text.strip().split('><')
    if len(speech_tokens) > 1:
        # Add back the missing '<' and '>' for proper parsing
        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
    speech_ids = []
    for token_str in speech_tokens:
        match = re.match(r'<\|s_(\d+)\|>', token_str)
        if match:
            speech_ids.append(int(match.group(1)))
    return speech_ids
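
# Example (follows directly from the parsing above):
#   parse_speech_token_string("<|s_12|><|s_345|>") -> [12, 345]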


class TritonPythonModel:
    """Triton Python model for Spark TTS.

    This model orchestrates the end-to-end TTS pipeline by coordinating
    between audio tokenizer, LLM, and vocoder components.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration
        """
        self.logger = pb_utils.Logger

        # Parse model parameters
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.logger.log_info(f"model_params:{model_params}")

        # One of "exponential", "equal", or "time_based"
        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")
        # self.dynamic_chunk_strategy = "equal"
        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")

        # Initialize tokenizer
        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")

        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)

        self.token_frame_rate = 25
        self.flow_pre_lookahead_len = 3
        self.token_hop_len = 15

        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
        if not os.path.exists(spk_info_path):
            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
        self.default_spk_info = spk_info["001"]

        self.http_client = httpx.AsyncClient()
        self.runtime_cache = {}
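        # self.runtime_cache maps request_id -> per-request streaming state for the
        # token2wav_dit model (conformer/estimator attention and CNN caches plus the
        # running mel/source/speech buffers). Entries are created in _process_request
        # and removed once the final chunk for that request has been sent.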

    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
        """Converts a tensor or list of speech token IDs to a string representation."""
        if isinstance(speech_tokens, torch.Tensor):
            # Ensure tensor is on CPU and flattened
            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
        speech_id_str = ""
        for token_id in speech_tokens:
            # Convert token ID back to the speech number N
            token_num = token_id - ORIGINAL_VOCAB_SIZE
            speech_id_str += f"<|s_{token_num}|>"
        return speech_id_str
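
    # Example: with ORIGINAL_VOCAB_SIZE = 151663, a token ID of 151675 is rendered
    # as "<|s_12|>", the inverse of the ID mapping applied in forward_llm_async.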

    async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
        """
        Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
        """
        full_text = f"{reference_text}{target_text}"
        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
        chat = [
            {"role": "user", "content": full_text},
            {"role": "assistant", "content": prompt_speech_tokens_str}
        ]

        payload = {
            "model": "trt_engines_bfloat16",
            "messages": chat,
            "max_tokens": 750,
            "temperature": 0.8,
            "top_p": 0.95,
            "top_k": 50,
            "repetition_penalty": 1.1,
            "stop": ["<|eos1|>", "<|eos|>"],
            "stream": True,
        }

        api_base = "http://localhost:8000/v1/chat/completions"
        buffer = ""
        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
            print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
            print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    line_data = line[len("data: "):].strip()
                    if line_data == "[DONE]":
                        break
                    try:
                        json_data = json.loads(line_data)
                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
                        if content:
                            buffer += content
                            print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
                            while True:
                                match = re.search(r"<\|s_(\d+)\|>", buffer)
                                if not match:
                                    break
                                token_num = int(match.group(1))
                                final_id = token_num + ORIGINAL_VOCAB_SIZE
                                yield final_id
                                buffer = buffer[match.end():]
                    except json.JSONDecodeError:
                        self.logger.log_info(f"Skipping non-JSON line: {line_data}")
                        continue

        # Process any remaining complete tokens in the buffer after the stream ends
        while True:
            match = re.search(r"<\|s_(\d+)\|>", buffer)
            if not match:
                break
            token_num = int(match.group(1))
            final_id = token_num + ORIGINAL_VOCAB_SIZE
            yield final_id
            buffer = buffer[match.end():]
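
    # Only complete "<|s_N|>" tokens are yielded; a partial token such as "<|s_12"
    # stays in the buffer until its closing "|>" arrives in a later delta, and any
    # tokens still buffered when the stream closes are flushed by the final loop.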

    def forward_audio_tokenizer(self, wav, wav_len):
        """Forward pass through the audio tokenizer component.

        Args:
            wav: Input waveform tensor
            wav_len: Waveform length tensor

        Returns:
            Prompt speech tokens tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['prompt_speech_tokens'],
            inputs=[wav, wav_len]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
        prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
        return prompt_speech_tokens

    def forward_speaker_embedding(self, wav):
        """Forward pass through the speaker embedding component.

        Args:
            wav: Input waveform tensor

        Returns:
            Prompt speaker embedding tensor
        """
        inference_request = pb_utils.InferenceRequest(
            model_name='speaker_embedding',
            requested_output_names=['prompt_spk_embedding'],
            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Extract and convert output tensors
        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
        return prompt_spk_embedding

    async def forward_token2wav(
            self,
            index: int,
            target_speech_tokens: torch.Tensor,
            request_id: str,
            reference_wav: object,
            reference_wav_len: object,
            finalize: bool = False) -> torch.Tensor:
        """Forward pass through the vocoder component.

        Args:
            index: Chunk index; used to set the request priority
            target_speech_tokens: Target speech tokens tensor
            request_id: Triton request id keying the per-request streaming cache
            reference_wav: Reference waveform input tensor
            reference_wav_len: Reference waveform length input tensor
            finalize: Whether this is the final chunk of the request

        Returns:
            Generated waveform tensor
        """
        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]

        # Optional cache inputs: after the first chunk, feed the streaming state back in
        if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None:
            # inputs_tensor.extend([
            #     pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()),
            #     pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()),
            #     pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()),
            #     pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()),
            #     pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()),
            #     pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()),
            #     pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()),
            # ])
            inputs_tensor.extend([
                self.runtime_cache[request_id]["conformer_cnn_cache"],
                self.runtime_cache[request_id]["conformer_att_cache"],
                self.runtime_cache[request_id]["estimator_cnn_cache"],
                self.runtime_cache[request_id]["estimator_att_cache"],
                self.runtime_cache[request_id]["mel"],
                self.runtime_cache[request_id]["source"],
                self.runtime_cache[request_id]["speech"],
            ])

        # Create and execute inference request
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav_dit',
            requested_output_names=[
                "waveform",
                "conformer_cnn_cache",
                "conformer_att_cache",
                "estimator_cnn_cache",
                "estimator_att_cache",
                "mel",
                "source",
                "speech",
            ],
            inputs=inputs_tensor,
            request_id=request_id,
            parameters={"priority": index + 1},
        )
        inference_response = await inference_request.async_exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())

        # Update the per-request streaming cache with the new state
        self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache")
        self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache")
        self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache")
        self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache")
        self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel")
        self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source")
        self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech")

        # Extract and convert output waveform
        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
        return waveform

    def _extract_speech_feat(self, speech):
        speech_feat = mel_spectrogram(
            speech,
            n_fft=1920,
            num_mels=80,
            sampling_rate=24000,
            hop_size=480,
            win_size=1920,
            fmin=0,
            fmax=8000,
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        return speech_feat
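
    # With hop_size=480 at a 24 kHz sampling rate the mel features run at 50 frames/s,
    # i.e. two mel frames per 25 Hz speech token, matching the 2 * token_len feature
    # truncation done in _process_request.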

    async def _process_request(self, request):
        request_id = request.request_id()
        if request_id not in self.runtime_cache:
            self.runtime_cache[request_id] = {
                "conformer_cnn_cache": None,
                "conformer_att_cache": None,
                "estimator_cnn_cache": None,
                "estimator_att_cache": None,
                "mel": None,
                "source": None,
                "speech": None,
            }

        # Extract input tensors
        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")

        # Process reference audio through audio tokenizer
        if wav is not None:
            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)

            wav_tensor = wav.as_numpy()
            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
            print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
            speech_feat = self._extract_speech_feat(prompt_speech_resample)

            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()

            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
            reference_text = reference_text[0][0].decode('utf-8')

            prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
            # reference_text = self.default_spk_info["prompt_text"]
            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
            # prompt_speech_feat = None
            # prompt_spk_embedding = None
        else:
            # Using a pre-cached reference (no reference_wav input) is not supported
            assert False, "using pre-cached reference text is not supported"
            reference_text = self.default_spk_info["prompt_text"]
            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
            prompt_speech_feat = None
            prompt_spk_embedding = None

        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
        target_text = target_text[0][0].decode('utf-8')
        print(f"target_text: {target_text}, time: {datetime.now()}")

        if self.decoupled:
            response_sender = request.get_response_sender()
            semantic_token_ids_arr = []
            token_offset, chunk_index = 0, 0
            start_time = time.time()
            this_token_hop_len = self.token_hop_len
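            # Streaming schedule: a chunk is vocoded once this_token_hop_len +
            # flow_pre_lookahead_len tokens are pending (15 + 3 = 18 tokens for the
            # first chunk, roughly 0.72 s of audio at 25 tokens/s). Later hops grow
            # according to dynamic_chunk_strategy, e.g. 25, 50, 100, ... tokens under
            # the "exponential" strategy.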
            print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
            async for generated_ids in self.forward_llm_async(
                target_text=target_text,
                reference_text=reference_text,
                prompt_speech_tokens=prompt_speech_tokens,
            ):
                if not generated_ids:
                    break
                semantic_token_ids_arr.append(generated_ids)
                print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
                while True:
                    pending_num = len(semantic_token_ids_arr) - token_offset
                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
                        print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
                        sub_tts_speech = await self.forward_token2wav(
                            chunk_index,
                            this_tts_speech_token, request_id, wav, wav_len, False
                        )
                        print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                        response_sender.send(inference_response)
                        token_offset += this_token_hop_len
                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
                        if self.dynamic_chunk_strategy == "exponential":
                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
                        elif self.dynamic_chunk_strategy == "equal":
                            this_token_hop_len = self.token_hop_len
                        elif self.dynamic_chunk_strategy == "time_based":
                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
                            cost_time = time.time() - start_time
                            duration = token_offset / self.token_frame_rate
                            if chunk_index > 0 and cost_time > 0:
                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
                                if avg_chunk_processing_time > 0:
                                    multiples = (duration - cost_time) / avg_chunk_processing_time
                                    self.logger.log_info(f"multiples: {multiples}")
                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
                                    if multiples > 4:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
                                    elif multiples > 2:
                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
                                    else:
                                        this_token_hop_len = self.token_hop_len
                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
                        chunk_index += 1
                    else:
                        break

            # Flush the remaining tokens as the final chunk (finalize=True)
            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
            response_sender.send(inference_response)

            if request_id in self.runtime_cache:
                del self.runtime_cache[request_id]
                self.logger.log_info(f"Deleted cache for request_id: {request_id}")

            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            self.logger.log_info("send tritonserver_response_complete_final to end")
        else:
            raise NotImplementedError("Non-decoupled mode is not supported")

    async def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of inference requests

        Returns:
            None. Responses containing generated audio are streamed back through
            each request's response sender (decoupled mode).
        """
        tasks = [
            asyncio.create_task(self._process_request(request))
            for request in requests
        ]
        await asyncio.gather(*tasks)
        return None

    def finalize(self):
        self.logger.log_info("Finalizing CosyVoice DIT model")
        if hasattr(self, "http_client"):
            asyncio.run(self.http_client.aclose())