# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of NVIDIA CORPORATION nor the names of its # contributors may be used to endorse or promote products derived # from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json import math import os import re import time from typing import Dict, List, Tuple, Optional, Union import asyncio import httpx import numpy as np import torch from torch.utils.dlpack import from_dlpack, to_dlpack import triton_python_backend_utils as pb_utils from transformers import AutoTokenizer import torchaudio from matcha.utils.audio import mel_spectrogram from datetime import datetime ORIGINAL_VOCAB_SIZE = 151663 torch.set_num_threads(1) def parse_speech_token_string(response_text: str) -> List[int]: """ Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs. """ speech_tokens = response_text.strip().split('><') if len(speech_tokens) > 1: # Add back the missing '<' and '>' for proper parsing speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens] speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens] speech_ids = [] for token_str in speech_tokens: match = re.match(r'<\|s_(\d+)\|>', token_str) if match: speech_ids.append(int(match.group(1))) return speech_ids class TritonPythonModel: """Triton Python model for Spark TTS. This model orchestrates the end-to-end TTS pipeline by coordinating between audio tokenizer, LLM, and vocoder components. """ def initialize(self, args): """Initialize the model. Args: args: Dictionary containing model configuration """ self.logger = pb_utils.Logger # Parse model parameters self.model_config = json.loads(args['model_config']) parameters = self.model_config['parameters'] model_params = {k: v["string_value"] for k, v in parameters.items()} self.logger.log_info(f"model_params:{model_params}") self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based" # self.dynamic_chunk_strategy = "equal" self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}") # Initialize tokenizer llm_tokenizer_dir = model_params["llm_tokenizer_dir"] self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) self.prompt_template = "<|sos|>{input_text}<|task_id|>" self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>") self.device = torch.device("cuda") self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config) self.token_frame_rate = 25 self.flow_pre_lookahead_len = 3 self.token_hop_len = 15 spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt") if not os.path.exists(spk_info_path): raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}") spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False) self.default_spk_info = spk_info["001"] self.http_client = httpx.AsyncClient() self.runtime_cache = {} def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str: """Converts a tensor or list of speech token IDs to a string representation.""" if isinstance(speech_tokens, torch.Tensor): # Ensure tensor is on CPU and flattened speech_tokens = speech_tokens.cpu().numpy().flatten().tolist() speech_id_str = "" for token_id in speech_tokens: # Convert token ID back to the speech number N token_num = token_id - ORIGINAL_VOCAB_SIZE speech_id_str += f"<|s_{token_num}|>" return speech_id_str async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]): """ Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response. """ full_text = f"{reference_text}{target_text}" prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens) chat = [ {"role": "user", "content": full_text}, {"role": "assistant", "content": prompt_speech_tokens_str} ] payload = { "model": "trt_engines_bfloat16", "messages": chat, "max_tokens": 750, "temperature": 0.8, "top_p": 0.95, "top_k": 50, "repetition_penalty": 1.1, "stop": ["<|eos1|>", "<|eos|>"], "stream": True, } api_base = "http://localhost:8000/v1/chat/completions" buffer = "" async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response: print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}") print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}") response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): line_data = line[len("data: "):].strip() if line_data == "[DONE]": break try: json_data = json.loads(line_data) content = json_data.get("choices", [{}])[0].get("delta", {}).get("content") if content: buffer += content print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}") while True: match = re.search(r"<\|s_(\d+)\|>", buffer) if not match: break token_num = int(match.group(1)) final_id = token_num + ORIGINAL_VOCAB_SIZE yield final_id buffer = buffer[match.end():] except json.JSONDecodeError: self.logger.log_info(f"Skipping non-JSON line: {line_data}") continue # Process any remaining complete tokens in the buffer after the stream ends while True: match = re.search(r"<\|s_(\d+)\|>", buffer) if not match: break token_num = int(match.group(1)) final_id = token_num + ORIGINAL_VOCAB_SIZE yield final_id buffer = buffer[match.end():] def forward_audio_tokenizer(self, wav, wav_len): """Forward pass through the audio tokenizer component. Args: wav: Input waveform tensor wav_len: Waveform length tensor Returns: Tuple of global and semantic tokens """ inference_request = pb_utils.InferenceRequest( model_name='audio_tokenizer', requested_output_names=['prompt_speech_tokens'], inputs=[wav, wav_len] ) inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) # Extract and convert output tensors prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens') prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu() return prompt_speech_tokens def forward_speaker_embedding(self, wav): """Forward pass through the speaker embedding component. Args: wav: Input waveform tensor Returns: Prompt speaker embedding tensor """ inference_request = pb_utils.InferenceRequest( model_name='speaker_embedding', requested_output_names=['prompt_spk_embedding'], inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))] ) inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) # Extract and convert output tensors prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding') prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack()) return prompt_spk_embedding async def forward_token2wav( self, index: int, target_speech_tokens: torch.Tensor, request_id: str, reference_wav: object, reference_wav_len: object, finalize: bool = None) -> torch.Tensor: """Forward pass through the vocoder component. Args: prompt_speech_tokens: Prompt speech tokens tensor prompt_speech_feat: Prompt speech feat tensor prompt_spk_embedding: Prompt spk embedding tensor target_speech_tokens: Target speech tokens tensor Returns: Generated waveform tensor """ target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens)) finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)) inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor] # optional cache inputs if self.runtime_cache[request_id]["conformer_cnn_cache"] is not None: # inputs_tensor.extend([ # pb_utils.Tensor("conformer_cnn_cache", self.runtime_cache[request_id]["conformer_cnn_cache"].as_numpy()), # pb_utils.Tensor("conformer_att_cache", self.runtime_cache[request_id]["conformer_att_cache"].as_numpy()), # pb_utils.Tensor("estimator_cnn_cache", self.runtime_cache[request_id]["estimator_cnn_cache"].as_numpy()), # pb_utils.Tensor("estimator_att_cache", self.runtime_cache[request_id]["estimator_att_cache"].as_numpy()), # pb_utils.Tensor("mel", self.runtime_cache[request_id]["mel"].as_numpy()), # pb_utils.Tensor("source", self.runtime_cache[request_id]["source"].as_numpy()), # pb_utils.Tensor("speech", self.runtime_cache[request_id]["speech"].as_numpy()), # ]) inputs_tensor.extend([ self.runtime_cache[request_id]["conformer_cnn_cache"], self.runtime_cache[request_id]["conformer_att_cache"], self.runtime_cache[request_id]["estimator_cnn_cache"], self.runtime_cache[request_id]["estimator_att_cache"], self.runtime_cache[request_id]["mel"], self.runtime_cache[request_id]["source"], self.runtime_cache[request_id]["speech"], ]) # Create and execute inference request inference_request = pb_utils.InferenceRequest( model_name='token2wav_dit', requested_output_names=[ "waveform", "conformer_cnn_cache", "conformer_att_cache", "estimator_cnn_cache", "estimator_att_cache", "mel", "source", "speech", ], inputs=inputs_tensor, request_id=request_id, parameters={"priority": index+1}, ) inference_response = await inference_request.async_exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) self.runtime_cache[request_id]["conformer_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_cnn_cache") self.runtime_cache[request_id]["conformer_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "conformer_att_cache") self.runtime_cache[request_id]["estimator_cnn_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_cnn_cache") self.runtime_cache[request_id]["estimator_att_cache"] = pb_utils.get_output_tensor_by_name(inference_response, "estimator_att_cache") self.runtime_cache[request_id]["mel"] = pb_utils.get_output_tensor_by_name(inference_response, "mel") self.runtime_cache[request_id]["source"] = pb_utils.get_output_tensor_by_name(inference_response, "source") self.runtime_cache[request_id]["speech"] = pb_utils.get_output_tensor_by_name(inference_response, "speech") # Extract and convert output waveform waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() return waveform def _extract_speech_feat(self, speech): speech_feat = mel_spectrogram( speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze( dim=0).transpose( 0, 1).to( self.device) speech_feat = speech_feat.unsqueeze(dim=0) return speech_feat async def _process_request(self, request): request_id = request.request_id() if request_id not in self.runtime_cache: self.runtime_cache[request_id] = { "conformer_cnn_cache": None, "conformer_att_cache": None, "estimator_cnn_cache": None, "estimator_att_cache": None, "mel": None, "source": None, "speech": None, } # Extract input tensors wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") # Process reference audio through audio tokenizer if wav is not None: wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len) prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0) wav_tensor = wav.as_numpy() wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]] print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}") prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor) speech_feat = self._extract_speech_feat(prompt_speech_resample) token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1]) prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half() prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous() reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() reference_text = reference_text[0][0].decode('utf-8') prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor) # reference_text = self.default_spk_info["prompt_text"] # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE # prompt_speech_feat = None # prompt_spk_embedding = None else: # using pre-cached reference text assert False, "using pre-cached reference text is not supported" reference_text = self.default_spk_info["prompt_text"] prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE prompt_speech_feat = None prompt_spk_embedding = None target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() target_text = target_text[0][0].decode('utf-8') print(f"target_text: {target_text}, time: {datetime.now()}") if self.decoupled: response_sender = request.get_response_sender() semantic_token_ids_arr = [] token_offset, chunk_index = 0, 0 start_time = time.time() this_token_hop_len = self.token_hop_len print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}") async for generated_ids in self.forward_llm_async( target_text=target_text, reference_text=reference_text, prompt_speech_tokens=prompt_speech_tokens, ): if not generated_ids: break semantic_token_ids_arr.append(generated_ids) print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}") while True: pending_num = len(semantic_token_ids_arr) - token_offset if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len: this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len] this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device) print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}") sub_tts_speech = await self.forward_token2wav( chunk_index, this_tts_speech_token, request_id, wav, wav_len, False ) print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}") audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) token_offset += this_token_hop_len self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}") if self.dynamic_chunk_strategy == "exponential": this_token_hop_len = self.token_frame_rate * (2 ** chunk_index) elif self.dynamic_chunk_strategy == "equal": this_token_hop_len = self.token_hop_len elif self.dynamic_chunk_strategy == "time_based": # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306 cost_time = time.time() - start_time duration = token_offset / self.token_frame_rate if chunk_index > 0 and cost_time > 0: avg_chunk_processing_time = cost_time / (chunk_index + 1) if avg_chunk_processing_time > 0: multiples = (duration - cost_time) / avg_chunk_processing_time self.logger.log_info(f"multiples: {multiples}") next_pending_num = len(semantic_token_ids_arr) - token_offset if multiples > 4: this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len elif multiples > 2: this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len else: this_token_hop_len = self.token_hop_len this_token_hop_len = max(self.token_hop_len, this_token_hop_len) chunk_index += 1 else: break this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device) sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True) audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) response_sender.send(inference_response) if request_id in self.runtime_cache: del self.runtime_cache[request_id] self.logger.log_info(f"Deleted cache for request_id: {request_id}") response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) self.logger.log_info("send tritonserver_response_complete_final to end") else: raise NotImplementedError("Decoupled mode is not supported") async def execute(self, requests): """Execute inference on the batched requests. Args: requests: List of inference requests Returns: List of inference responses containing generated audio """ tasks = [ asyncio.create_task(self._process_request(request)) for request in requests ] await asyncio.gather(*tasks) return None def finalize(self): self.logger.log_info("Finalizing CosyVoice DIT model") if hasattr(self, "http_client"): asyncio.run(self.http_client.aclose())