import json
import re
import time
import asyncio
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
import httpx
import torchaudio
from functools import partial
from matcha.utils.audio import mel_spectrogram as matcha_mel_spectrogram

torch.set_num_threads(1)

# CosyVoice3 mel params: fmax=None (Nyquist), center=False
mel_spectrogram = partial(matcha_mel_spectrogram,
                          n_fft=1920, num_mels=80, sampling_rate=24000,
                          hop_size=480, win_size=1920, fmin=0, fmax=None, center=False)
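
# A note on rates, derived from the constants above: at 24 kHz with
# hop_size=480 the mel runs at 24000 / 480 = 50 frames/s, while speech tokens
# arrive at 25 tokens/s, so each token corresponds to 50 / 25 = 2 mel frames
# (token_mel_ratio = 2 in initialize() below).
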

def parse_speech_token_string(response_text):
    """Parse speech tokens from a string like '<|s_123|><|s_456|>' into a list of int IDs."""
    speech_tokens = response_text.strip().split('><')
    if len(speech_tokens) > 1:
        # Splitting on '><' strips the brackets at the split points; restore them.
        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
    speech_ids = []
    for token_str in speech_tokens:
        match = re.match(r'<\|s_(\d+)\|>', token_str)
        if match:
            speech_ids.append(int(match.group(1)))
    return speech_ids
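
# Example (illustrative): parse_speech_token_string('<|s_12|><|s_345|>') -> [12, 345]
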

class TritonPythonModel:
    """CosyVoice3 BLS orchestrator for Triton Inference Server.

    Orchestrates: audio_tokenizer, speaker_embedding, remote LLM (httpx),
    token2wav (flow-only), and vocoder (CausalHiFTGenerator).
    Supports both streaming (decoupled) and offline (non-decoupled) modes.
    """

    def initialize(self, args):
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args['model_config'])
        parameters = self.model_config['parameters']
        model_params = {k: v["string_value"] for k, v in parameters.items()}
        self.device = torch.device("cuda")
        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
        # Streaming config
        self.token_frame_rate = 25       # speech tokens per second
        self.flow_pre_lookahead_len = 3  # lookahead tokens fed to the flow model per chunk
        self.token_hop_len = 15          # base number of new tokens consumed per chunk
        self.token_mel_ratio = 2         # mel frames per speech token
        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")
        self.logger.log_info(f"CosyVoice3 BLS initialized, decoupled={self.decoupled}, "
                             f"chunk_strategy={self.dynamic_chunk_strategy}")
        # HTTP client for remote LLM (trtllm-serve default port: 8000)
        self.http_client = httpx.AsyncClient()
        self.api_base = model_params.get("llm_api_base", "http://localhost:8000/v1/chat/completions")
        # Speaker cache to avoid redundant audio_tokenizer/speaker_embedding calls.
        # Keyed by reference_text only, so two different reference clips with the
        # same transcript would collide; entries are never evicted.
        self.speaker_cache = {}

    def _convert_speech_tokens_to_str(self, speech_tokens):
        """Convert speech token IDs tensor/list to a string like '<|s_N|>'."""
        if isinstance(speech_tokens, torch.Tensor):
            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
        return "".join(f"<|s_{int(tid)}|>" for tid in speech_tokens)

    def _extract_speech_feat(self, speech):
        """Extract mel spectrogram from 24kHz speech for the flow prompt."""
        speech_feat = mel_spectrogram(speech).squeeze(dim=0).transpose(0, 1)
        speech_feat = speech_feat.unsqueeze(dim=0).to(self.device)  # [1, T, 80]
        return speech_feat

    async def forward_llm_streaming(self, target_text, reference_text, prompt_speech_tokens):
        """Async generator: stream LLM tokens via httpx SSE."""
        full_text = f"{reference_text}{target_text}"
        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
        chat = [
            {"role": "user", "content": full_text},
            {"role": "assistant", "content": prompt_speech_tokens_str}
        ]
        payload = {
            "model": "trt_engines_bfloat16",
            "messages": chat,
            "max_tokens": 750,
            "temperature": 0.8,
            "top_p": 0.95,
            "top_k": 50,
            "repetition_penalty": 1.1,
            "stop": ["<|eos1|>", "<|eos|>"],
            "stream": True,
        }
        buffer = ""
        async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    line_data = line[len("data: "):].strip()
                    if line_data == "[DONE]":
                        break
                    try:
                        json_data = json.loads(line_data)
                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
                        if content:
                            buffer += content
                            # Emit every complete '<|s_N|>' token; keep any partial tail in the buffer.
                            while True:
                                match = re.search(r"<\|s_(\d+)\|>", buffer)
                                if not match:
                                    break
                                token_num = int(match.group(1))
                                # final_id = token_num + ORIGINAL_VOCAB_SIZE
                                yield token_num
                                buffer = buffer[match.end():]
                    except json.JSONDecodeError:
                        continue
        # Flush remaining tokens
        while True:
            match = re.search(r"<\|s_(\d+)\|>", buffer)
            if not match:
                break
            token_num = int(match.group(1))
            # final_id = token_num + ORIGINAL_VOCAB_SIZE
            yield token_num
            buffer = buffer[match.end():]
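
    # For reference: each SSE event from an OpenAI-compatible endpoint such as
    # trtllm-serve is expected to look like (illustrative, abbreviated)
    #     data: {"choices": [{"delta": {"content": "<|s_123|><|s_4"}}]}
    # and a delta may end mid-token, which is why the loop above keeps the
    # unmatched tail in `buffer` until the next event arrives.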

    async def forward_llm_offline(self, target_text, reference_text, prompt_speech_tokens):
        """Non-streaming LLM call, returns all speech token IDs at once."""
        full_text = f"{reference_text}{target_text}"
        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
        chat = [
            {"role": "user", "content": full_text},
            {"role": "assistant", "content": prompt_speech_tokens_str}
        ]
        payload = {
            "model": "trt_engines_bfloat16",
            "messages": chat,
            "max_tokens": 750,
            "temperature": 0.8,
            "top_p": 0.95,
            "top_k": 50,
            "repetition_penalty": 1.1,
            "stop": ["<|eos1|>", "<|eos|>"],
            "stream": False,
        }
        response = await self.http_client.post(self.api_base, json=payload, timeout=None)
        response.raise_for_status()
        response_json = response.json()
        generated_content = response_json['choices'][0]['message']['content']
        speech_ids = parse_speech_token_string(generated_content)
        # return [sid + ORIGINAL_VOCAB_SIZE for sid in speech_ids]
        return speech_ids

    def forward_audio_tokenizer(self, wav, wav_len):
        """BLS call to audio_tokenizer."""
        inference_request = pb_utils.InferenceRequest(
            model_name='audio_tokenizer',
            requested_output_names=['prompt_speech_tokens'],
            inputs=[wav, wav_len]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())
        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(
            inference_response, 'prompt_speech_tokens')
        return from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()

    def forward_speaker_embedding(self, wav):
        """BLS call to speaker_embedding."""
        inference_request = pb_utils.InferenceRequest(
            model_name='speaker_embedding',
            requested_output_names=['prompt_spk_embedding'],
            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())
        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(
            inference_response, 'prompt_spk_embedding')
        return from_dlpack(prompt_spk_embedding.to_dlpack())

    async def forward_token2wav(self, target_speech_tokens, prompt_speech_tokens,
                                prompt_speech_feat, prompt_spk_embedding,
                                request_id, token_offset=None, finalize=True,
                                priority=100):
        """Async BLS call to token2wav (flow-only). Returns mel tensor."""
        target_tokens_pb = pb_utils.Tensor.from_dlpack(
            "target_speech_tokens", to_dlpack(target_speech_tokens))
        prompt_tokens_pb = pb_utils.Tensor.from_dlpack(
            "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
        prompt_feat_pb = pb_utils.Tensor.from_dlpack(
            "prompt_speech_feat", to_dlpack(prompt_speech_feat))
        prompt_emb_pb = pb_utils.Tensor.from_dlpack(
            "prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
        inputs = [target_tokens_pb, prompt_tokens_pb, prompt_feat_pb, prompt_emb_pb]
        if token_offset is not None:
            # Streaming-only inputs
            inputs.append(pb_utils.Tensor("token_offset",
                                          np.array([[token_offset]], dtype=np.int32)))
            inputs.append(pb_utils.Tensor("finalize",
                                          np.array([[finalize]], dtype=np.bool_)))
        inference_request = pb_utils.InferenceRequest(
            model_name='token2wav',
            requested_output_names=['mel'],
            inputs=inputs,
            request_id=request_id,
            parameters={"priority": priority},
        )
        inference_response = await inference_request.async_exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())
        mel = pb_utils.get_output_tensor_by_name(inference_response, 'mel')
        return from_dlpack(mel.to_dlpack())
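
    # Calling convention, as used by the methods below: in streaming mode each
    # call re-sends the full token prefix with `token_offset` marking how many
    # tokens were already decoded, and the returned mel is treated as the newly
    # generated frames only (see the torch.cat in the streaming loop);
    # `finalize=True` marks the last chunk. Offline calls omit `token_offset`
    # and decode everything in one shot.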

    async def forward_vocoder(self, mel, finalize):
        """Async BLS call to vocoder. Returns speech tensor."""
        if mel.dim() == 2:
            mel = mel.unsqueeze(0)  # [80, T] -> [1, 80, T]
        mel_pb = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel.float()))
        finalize_pb = pb_utils.Tensor("finalize",
                                      np.array([[finalize]], dtype=np.bool_))
        inference_request = pb_utils.InferenceRequest(
            model_name='vocoder',
            requested_output_names=['tts_speech'],
            inputs=[mel_pb, finalize_pb],
        )
        inference_response = await inference_request.async_exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())
        speech = pb_utils.get_output_tensor_by_name(inference_response, 'tts_speech')
        return from_dlpack(speech.to_dlpack()).cpu()
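
    # Note: the streaming path below re-sends the whole accumulated mel to the
    # vocoder on every chunk and uses `speech_offset` to emit only the samples
    # that have not been sent to the client yet.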

    def _prepare_prompt(self, request):
        """Extract reference audio, tokenize, compute speaker embedding and mel feat."""
        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text")
        reference_text = reference_text.as_numpy()[0][0].decode('utf-8') if reference_text is not None else ""
        if '<|endofprompt|>' not in reference_text:
            reference_text = 'You are a helpful assistant.<|endofprompt|>' + reference_text
        # Check speaker cache
        if reference_text in self.speaker_cache:
            cached = self.speaker_cache[reference_text]
            return (cached['prompt_speech_tokens_for_llm'], cached['prompt_speech_tokens'],
                    cached['prompt_speech_feat'], cached['prompt_spk_embedding'], reference_text)
        # Audio tokenizer
        wav_np = wav.as_numpy()
        wav_len_val = wav_len.as_numpy()[0][0]
        prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
        prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)  # [1, T]
        # Speaker embedding
        wav_tensor = torch.from_numpy(wav_np)
        wav_tensor = wav_tensor[:, :wav_len_val]
        prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
        # Mel extraction at 24kHz with CosyVoice3 params
        prompt_speech_resample = torchaudio.transforms.Resample(
            orig_freq=16000, new_freq=24000)(wav_tensor)
        speech_feat = self._extract_speech_feat(prompt_speech_resample)
        # Keep full tokens for LLM prefill (untruncated)
        prompt_speech_tokens_for_llm = prompt_speech_tokens.clone()
        # Align prompt speech feat and tokens to the 2:1 ratio (for flow model only)
        token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
        prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
        prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
        # Cache
        self.speaker_cache[reference_text] = {
            'prompt_speech_tokens_for_llm': prompt_speech_tokens_for_llm,
            'prompt_speech_tokens': prompt_speech_tokens,
            'prompt_speech_feat': prompt_speech_feat,
            'prompt_spk_embedding': prompt_spk_embedding,
        }
        return prompt_speech_tokens_for_llm, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, reference_text
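
    # Worked example for the 2:1 alignment above (illustrative numbers): a 3 s
    # reference clip resampled to 24 kHz yields ~150 mel frames (50 frames/s)
    # and ~75 speech tokens (25 tokens/s), so token_len = min(150 // 2, 75) = 75
    # and the flow prompt keeps 150 mel frames and 75 tokens.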

    async def _process_request_streaming(self, request):
        """Process a single request in streaming (decoupled) mode."""
        request_id = request.request_id()
        response_sender = request.get_response_sender()
        try:
            prompt_speech_tokens_for_llm, prompt_speech_tokens, prompt_speech_feat, \
                prompt_spk_embedding, reference_text = self._prepare_prompt(request)
            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
            target_text = target_text[0][0].decode('utf-8')
            semantic_token_ids_arr = []
            token_offset = 0
            chunk_index = 0
            this_token_hop_len = self.token_hop_len
            accumulated_mel = None
            speech_offset = 0
            start_time = time.time()
            async for generated_id in self.forward_llm_streaming(
                target_text=target_text,
                reference_text=reference_text,
                prompt_speech_tokens=prompt_speech_tokens_for_llm,
            ):
                semantic_token_ids_arr.append(generated_id)
                while True:
                    pending_num = len(semantic_token_ids_arr) - token_offset
                    if pending_num < this_token_hop_len + self.flow_pre_lookahead_len:
                        break
                    # Prepare tokens for this chunk
                    end_idx = token_offset + this_token_hop_len + self.flow_pre_lookahead_len
                    this_tokens = torch.tensor(
                        semantic_token_ids_arr[:end_idx]
                    ).unsqueeze(0).to(torch.int32).to(self.device)
                    # Call token2wav (flow-only) -> mel_chunk
                    mel_chunk = await self.forward_token2wav(
                        this_tokens, prompt_speech_tokens,
                        prompt_speech_feat, prompt_spk_embedding,
                        request_id, token_offset=token_offset, finalize=False,
                        priority=chunk_index + 1,
                    )
                    # Accumulate mel
                    if mel_chunk.dim() == 2:
                        mel_chunk = mel_chunk.unsqueeze(0)
                    if accumulated_mel is None:
                        accumulated_mel = mel_chunk
                    else:
                        accumulated_mel = torch.cat([accumulated_mel, mel_chunk], dim=2)
                    # Call vocoder
                    speech = await self.forward_vocoder(accumulated_mel, finalize=False)
                    # Extract new speech
                    new_speech = speech[:, speech_offset:]
                    speech_offset += new_speech.shape[1]
                    if new_speech.shape[1] > 0:
                        audio_tensor = pb_utils.Tensor.from_dlpack(
                            "waveform", to_dlpack(new_speech))
                        inference_response = pb_utils.InferenceResponse(
                            output_tensors=[audio_tensor])
                        response_sender.send(inference_response)
                    token_offset += this_token_hop_len
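                    # Chunk-size scheduling (illustrative numbers from the
                    # defaults above): "exponential" grows the hop as
                    # 15, 25, 50, 100, ... tokens (token_frame_rate * 2**chunk_index
                    # after the first chunk), trading a small first-chunk latency
                    # for throughput; "time_based" compares generated-audio
                    # duration against wall clock and enlarges the hop only when
                    # synthesis runs well ahead of real time.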
                    # Dynamic chunk strategy
                    if self.dynamic_chunk_strategy == "exponential":
                        this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
                    elif self.dynamic_chunk_strategy == "time_based":
                        cost_time = time.time() - start_time
                        duration = token_offset / self.token_frame_rate
                        if chunk_index > 0 and cost_time > 0:
                            avg_chunk_time = cost_time / (chunk_index + 1)
                            if avg_chunk_time > 0:
                                multiples = (duration - cost_time) / avg_chunk_time
                                next_pending = len(semantic_token_ids_arr) - token_offset
                                if multiples > 4:
                                    this_token_hop_len = (next_pending // self.token_hop_len + 1) * self.token_hop_len
                                elif multiples > 2:
                                    this_token_hop_len = (next_pending // self.token_hop_len) * self.token_hop_len
                                else:
                                    this_token_hop_len = self.token_hop_len
                                this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
                    chunk_index += 1
            # Final chunk with remaining tokens
            if len(semantic_token_ids_arr) > 0:
                remaining_tokens = torch.tensor(
                    semantic_token_ids_arr
                ).unsqueeze(0).to(torch.int32).to(self.device)
                mel_chunk = await self.forward_token2wav(
                    remaining_tokens, prompt_speech_tokens,
                    prompt_speech_feat, prompt_spk_embedding,
                    request_id, token_offset=token_offset, finalize=True,
                    priority=chunk_index + 1,
                )
                if mel_chunk.dim() == 2:
                    mel_chunk = mel_chunk.unsqueeze(0)
                if accumulated_mel is None:
                    accumulated_mel = mel_chunk
                else:
                    accumulated_mel = torch.cat([accumulated_mel, mel_chunk], dim=2)
                speech = await self.forward_vocoder(accumulated_mel, finalize=True)
                new_speech = speech[:, speech_offset:]
                if new_speech.shape[1] > 0:
                    audio_tensor = pb_utils.Tensor.from_dlpack(
                        "waveform", to_dlpack(new_speech))
                    inference_response = pb_utils.InferenceResponse(
                        output_tensors=[audio_tensor])
                    response_sender.send(inference_response)
            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
        except Exception as e:
            self.logger.log_error(f"Error in streaming request: {e}")
            error_response = pb_utils.InferenceResponse(
                error=pb_utils.TritonError(str(e)))
            response_sender.send(error_response)
            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

    async def _process_request_offline(self, request):
        """Process a single request in offline (non-decoupled) mode."""
        request_id = request.request_id()
        prompt_speech_tokens_for_llm, prompt_speech_tokens, prompt_speech_feat, \
            prompt_spk_embedding, reference_text = self._prepare_prompt(request)
        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
        target_text = target_text[0][0].decode('utf-8')
        # Get all speech tokens at once (use full untruncated prompt tokens for LLM)
        all_token_ids = await self.forward_llm_offline(
            target_text=target_text,
            reference_text=reference_text,
            prompt_speech_tokens=prompt_speech_tokens_for_llm,
        )
        if len(all_token_ids) == 0:
            raise pb_utils.TritonModelException("LLM generated no speech tokens")
        all_tokens = torch.tensor(all_token_ids).unsqueeze(0).to(torch.int32).to(self.device)
        # token2wav (no token_offset, finalize=True) -> full mel
        mel = await self.forward_token2wav(
            all_tokens, prompt_speech_tokens,
            prompt_speech_feat, prompt_spk_embedding,
            request_id,
        )
        # vocoder -> full speech
        speech = await self.forward_vocoder(mel, finalize=True)
        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(speech))
        return pb_utils.InferenceResponse(output_tensors=[audio_tensor])

    async def execute(self, requests):
        if self.decoupled:
            tasks = [
                asyncio.create_task(self._process_request_streaming(request))
                for request in requests
            ]
            await asyncio.gather(*tasks)
            return None
        else:
            responses = []
            for request in requests:
                try:
                    response = await self._process_request_offline(request)
                    responses.append(response)
                except Exception as e:
                    self.logger.log_error(f"Error in offline request: {e}")
                    responses.append(pb_utils.InferenceResponse(
                        error=pb_utils.TritonError(str(e))))
            return responses

    def finalize(self):
        self.logger.log_info("Finalizing CosyVoice3 BLS model")
        if hasattr(self, "http_client"):
            # Assumes no event loop is running in this thread when Triton
            # unloads the model; otherwise asyncio.run() would raise.
            asyncio.run(self.http_client.aclose())