add cosyvoice3

root, 1 month ago
commit 8f0b28b861

+ 1 - 0
runtime/triton_trtllm/client_grpc.py

@@ -281,6 +281,7 @@ def get_args():
         choices=[
             "f5_tts",
             "spark_tts",
+            "cosyvoice3",
             "cosyvoice2",
             "cosyvoice2_dit"],
         help="triton model_repo module name to request",

+ 4 - 3
runtime/triton_trtllm/client_http.py

@@ -37,14 +37,14 @@ def get_args():
     parser.add_argument(
         "--server-url",
         type=str,
-        default="localhost:8000",
+        default="localhost:18000",
         help="Address of the server",
     )
 
     parser.add_argument(
         "--reference-audio",
         type=str,
-        default="../../example/prompt_audio.wav",
+        default="./prompt_audio.wav",
         help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
     )
 
@@ -65,9 +65,10 @@ def get_args():
     parser.add_argument(
         "--model-name",
         type=str,
-        default="spark_tts",
+        default="cosyvoice3",
         choices=[
             "f5_tts",
+            "cosyvoice3",
             "spark_tts",
             "cosyvoice2"],
         help="triton model_repo module name to request",

+ 512 - 0
runtime/triton_trtllm/infer_cosyvoice3.py

@@ -0,0 +1,512 @@
+""" Example Usage
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 infer_cosyvoice3.py \
+            --output-dir $output_dir \
+            --llm-model-name-or-path $huggingface_model_local_dir \
+            --token2wav-path $token2wav_model_dir \
+            --backend $backend \
+            --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+            --engine-dir $trt_engines_dir \
+            --split-name ${dataset} || exit 1
+"""
+import argparse
+import json
+import os
+import time
+import asyncio
+
+import torch
+import torchaudio
+import s3tokenizer
+import soundfile as sf
+import requests
+import httpx
+from transformers import AutoTokenizer
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from functools import partial
+from tqdm import tqdm
+
+from token2wav_cosyvoice3 import CosyVoice3_Token2Wav
+
+try:
+    torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+    pass
+
+
+async def send_request_async(client, url, payload):
+    response = await client.post(url, json=payload, timeout=None)
+    response.raise_for_status()
+    response_json = response.json()
+    return response_json['choices'][0]['message']['content']
+
+
+async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
+    async with httpx.AsyncClient() as client:
+        tasks = []
+        for chat in chats:
+            payload = {
+                "model": model_name,
+                "messages": chat,
+                "max_tokens": 2048,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": 1.1,
+                "stop": ["<|eos1|>", "<|eos|>"],
+                "stream": False,
+            }
+            tasks.append(send_request_async(client, api_base, payload))
+        return await asyncio.gather(*tasks)
+
+
+def extract_speech_ids(speech_tokens_str):
+    """Extract speech IDs from token strings like <|s_23456|>"""
+    speech_ids = []
+    for token_str in speech_tokens_str:
+        if token_str.startswith('<|s_') and token_str.endswith('|>'):
+            num_str = token_str[4:-2]
+            num = int(num_str)
+            speech_ids.append(num)
+        else:
+            print(f"Unexpected token: {token_str}")
+    return speech_ids
+
+
+def convert_cosy3_tokens_to_speech_id_str(cosy3_tokens):
+    """Convert CosyVoice3 tokens to speech IDs string like <|s_23456|>"""
+    if hasattr(cosy3_tokens, 'cpu'):
+        cosy3_tokens = cosy3_tokens.cpu().numpy().tolist()
+    speech_id_str = ""
+    for token in cosy3_tokens:
+        speech_id_str += f"<|s_{token}|>"
+    return speech_id_str
+
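+# Round-trip example for the two helpers above (hypothetical token IDs):
+#   convert_cosy3_tokens_to_speech_id_str([12, 345]) -> "<|s_12|><|s_345|>"
+#   extract_speech_ids(["<|s_12|>", "<|s_345|>"])    -> [12, 345]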
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice3")
+    parser.add_argument(
+        "--split-name", type=str, default="wenetspeech4tts",
+        help="huggingface dataset split name",
+    )
+    parser.add_argument(
+        "--output-dir", required=True, type=str, help="dir to save result",
+    )
+    parser.add_argument(
+        "--batch-size", default=1, type=int,
+        help="batch size (per-device) for LLM inference",
+    )
+    parser.add_argument(
+        "--token2wav-batch-size", default=1, type=int,
+        help="batch size (per-device) for token2wav inference",
+    )
+    parser.add_argument(
+        "--num-workers", type=int, default=0, help="workers for dataloader",
+    )
+    parser.add_argument(
+        "--prefetch", type=int, default=None, help="prefetch for dataloader",
+    )
+    parser.add_argument(
+        "--llm-model-name-or-path", required=True, type=str,
+        help="CosyVoice3 HF LLM path (e.g. ./hf_cosyvoice3_llm)",
+    )
+    parser.add_argument(
+        "--token2wav-path", required=True, type=str,
+        help="CosyVoice3 model path (e.g. /workspace_yuekai/HF/Fun-CosyVoice3-0.5B-2512)",
+    )
+    parser.add_argument(
+        "--enable-trt", action="store_true",
+        help="Enable TensorRT for flow decoder estimator",
+    )
+    parser.add_argument(
+        "--streaming", action="store_true",
+        help="Enable streaming for flow decoder estimator",
+    )
+    parser.add_argument(
+        "--top-p", type=float, default=0.95, help="top p for sampling",
+    )
+    parser.add_argument(
+        "--temperature", type=float, default=0.8, help="temperature for sampling",
+    )
+    parser.add_argument(
+        "--top-k", type=int, default=15, help="top k for sampling",
+    )
+    parser.add_argument(
+        "--backend", type=str, default="hf",
+        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
+        help="Backend to use for LLM inference",
+    )
+    parser.add_argument(
+        "--engine-dir", type=str, default=None,
+        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
+    )
+    parser.add_argument(
+        "--kv-cache-free-gpu-memory-fraction", type=float, default=0.6,
+        help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
+    )
+    parser.add_argument(
+        "--openai-api-base", type=str,
+        default="http://localhost:8000/v1/chat/completions",
+        help="OpenAI API base URL (for trtllm-serve backend)",
+    )
+    parser.add_argument(
+        "--openai-model-name", type=str, default="trt_engines_bfloat16",
+        help="Model name to use with OpenAI API (for trtllm-serve backend)",
+    )
+    parser.add_argument(
+        "--epoch", type=int, default=1, help="Epoch to run",
+    )
+    return parser.parse_args()
+
+
+def data_collator(batch, tokenizer, s3_tokenizer):
+    """Data collator: extracts cosy3 tokens from prompt_audio using v3 s3 tokenizer."""
+    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
+    target_sample_rate = 16000
+
+    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
+    mels, prompt_audio_cosy3tokens_list, full_text_list = [], [], []
+    chat_list = []
+
+    for item in batch:
+        prompt_text, target_text = item["prompt_text"], item["target_text"]
+        prompt_text_list.append(prompt_text)
+        full_text = 'You are a helpful assistant.<|endofprompt|>' + prompt_text + target_text
+        full_text_list.append(full_text)
+
+        # Get prompt audio (convert to 16kHz for s3 tokenizer)
+        ref_audio = torch.from_numpy(item["prompt_audio"]["array"]).float().unsqueeze(0)
+        ref_sr = item["prompt_audio"]["sampling_rate"]
+        if ref_sr != target_sample_rate:
+            ref_audio = torchaudio.transforms.Resample(ref_sr, target_sample_rate)(ref_audio)
+        prompt_audio_list.append(ref_audio)
+
+        # Extract cosy3 tokens from prompt_audio using v3 s3 tokenizer
+        mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
+
+    # Batch tokenization with v3 tokenizer
+    if len(mels) > 0:
+        mels_padded, mels_lens = s3tokenizer.padding(mels)
+        codes, codes_lens = s3_tokenizer.quantize(mels_padded.to(device), mels_lens.to(device))
+        for i in range(len(codes)):
+            prompt_audio_cosy3tokens_list.append(codes[i, :codes_lens[i].item()])
+
+    # Build LLM inputs
+    for i, prompt_audio_cosy3tokens in enumerate(prompt_audio_cosy3tokens_list):
+        prompt_audio_cosy3_id_str = convert_cosy3_tokens_to_speech_id_str(
+            prompt_audio_cosy3tokens)
+        chat = [
+            {"role": "user", "content": full_text_list[i]},
+            {"role": "assistant", "content": prompt_audio_cosy3_id_str}
+        ]
+        chat_list.append(chat)
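+        # continue_final_message=True keeps the assistant turn open, so the LLM
+        # continues emitting <|s_N|> speech tokens after the prompt-audio tokens.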
+        input_ids = tokenizer.apply_chat_template(
+            chat, tokenize=True, return_tensors='pt', continue_final_message=True)
+        input_ids_list.append(input_ids.squeeze(0))
+
+    ids = [item["id"] for item in batch]
+
+    return {
+        "input_ids": input_ids_list,
+        "ids": ids,
+        "prompt_text": prompt_text_list,
+        "prompt_audio_list": prompt_audio_list,
+        "chat_list": chat_list,
+    }
+
+
+def main(args):
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    assert torch.cuda.is_available()
+    local_rank = 0
+    device = torch.device(f"cuda:{local_rank}")
+
+    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
+
+    if args.backend == "hf":
+        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
+        model.eval()
+        model.to(device)
+        runner = None
+    elif args.backend == "trtllm":
+        if args.engine_dir is None:
+            raise ValueError("--engine-dir is required when backend is 'trtllm'")
+        runtime_rank = tensorrt_llm.mpi_rank()
+        model = None
+        runner_kwargs = dict(
+            engine_dir=args.engine_dir,
+            rank=runtime_rank,
+            max_output_len=2048,
+            enable_context_fmha_fp32_acc=False,
+            max_batch_size=args.batch_size,
+            max_input_len=512,
+            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
+            cuda_graph_mode=False,
+            gather_generation_logits=False,
+        )
+        runner = ModelRunnerCpp.from_dir(**runner_kwargs)
+    elif args.backend == "vllm":
+        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
+        runner = None
+    elif args.backend == "trtllm-serve":
+        model = None
+        runner = None
+    else:
+        raise ValueError(f"Unsupported backend: {args.backend}")
+
+    token2wav_model = CosyVoice3_Token2Wav(
+        model_dir=args.token2wav_path, enable_trt=args.enable_trt, device_id=local_rank, streaming=args.streaming
+    )
+
+    # Load v3 s3 tokenizer for prompt audio tokenization in data_collator
+    s3_tokenizer = s3tokenizer.load_model(
+        f"{args.token2wav_path}/speech_tokenizer_v3.onnx"
+    ).to(device).eval()
+
+    dataset = load_dataset(
+        "yuekai/seed_tts_cosy2",
+        split=args.split_name,
+        trust_remote_code=True,
+    )
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        prefetch_factor=args.prefetch,
+        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
+    )
+
+    for epoch in range(args.epoch):
+        print(f"Running epoch {epoch}")
+        total_llm_time = 0
+        total_token2wav_time = 0
+        total_data_load_time = 0
+        total_llm_post_processing_time = 0
+        total_audio_save_time = 0
+        total_audio_samples = 0
+        start_time = time.time()
+
+        progress_bar = tqdm(total=len(dataset), desc="Processing", unit="wavs")
+
+        last_batch_end_time = time.time()
+        for batch in dataloader:
+            data_loaded_time = time.time()
+            total_data_load_time += data_loaded_time - last_batch_end_time
+
+            with torch.no_grad():
+                llm_start_time = time.time()
+
+                if args.backend == "hf":
+                    input_ids_list = batch["input_ids"]
+                    if len(input_ids_list) == 1:
+                        input_ids = input_ids_list[0].unsqueeze(0)
+                        attention_mask = torch.ones_like(input_ids)
+                    else:
+                        max_len = max([len(ids) for ids in input_ids_list])
+                        input_ids_list_new = [
+                            torch.cat([ids, torch.full((max_len - len(ids),), tokenizer.pad_token_id)])
+                            for ids in input_ids_list
+                        ]
+                        input_ids = torch.stack(input_ids_list_new)
+                        attention_mask = torch.zeros_like(input_ids)
+                        for i in range(len(input_ids_list)):
+                            attention_mask[i, :len(input_ids_list[i])] = 1
+
+                    outputs = model.generate(
+                        input_ids=input_ids.to(device),
+                        attention_mask=attention_mask.to(device),
+                        max_new_tokens=2048,
+                        do_sample=True,
+                        top_p=args.top_p,
+                        temperature=args.temperature,
+                        repetition_penalty=1.1,
+                        top_k=args.top_k,
+                    )
+                    torch.cuda.synchronize()
+
+                elif args.backend == "trtllm":
+                    batch_input_ids = list(batch["input_ids"])
+                    input_lengths = [x.size(0) for x in batch_input_ids]
+
+                    end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
+                    outputs = runner.generate(
+                        batch_input_ids=batch_input_ids,
+                        max_new_tokens=2048,
+                        end_id=end_id,
+                        pad_id=end_id,
+                        temperature=args.temperature,
+                        top_k=args.top_k,
+                        top_p=args.top_p,
+                        repetition_penalty=1.1,
+                        num_return_sequences=1,
+                        streaming=False,
+                        output_sequence_lengths=True,
+                        output_generation_logits=False,
+                        return_dict=True,
+                        return_all_generated_tokens=False
+                    )
+                    torch.cuda.synchronize()
+                    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
+                    num_output_sents, num_beams, _ = output_ids.size()
+                    assert num_beams == 1
+                    batch_size = len(batch["input_ids"])
+                    num_return_sequences = num_output_sents // batch_size
+                    assert num_return_sequences == 1
+                    outputs = []
+                    for i in range(batch_size * num_return_sequences):
+                        batch_idx = i // num_return_sequences
+                        output_begin = input_lengths[batch_idx]
+                        output_end = sequence_lengths[i][0]
+                        outputs_i = output_ids[i][0][:output_end].tolist()
+                        outputs.append(outputs_i)
+
+                elif args.backend == "vllm":
+                    input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
+                    sampling_params = SamplingParams(
+                        temperature=args.temperature,
+                        top_p=args.top_p,
+                        top_k=args.top_k,
+                        repetition_penalty=1.1,
+                        max_tokens=2048,
+                    )
+                    outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
+                    for j, output in enumerate(outputs):
+                        outputs[j] = input_ids_list[j] + list(output.outputs[0].token_ids)
+
+                elif args.backend == "trtllm-serve":
+                    if args.batch_size > 1:
+                        outputs = asyncio.run(send_batch_requests_async(
+                            args.openai_api_base,
+                            args.openai_model_name,
+                            batch["chat_list"],
+                            args.temperature,
+                            args.top_p,
+                            args.top_k,
+                        ))
+                    else:
+                        outputs = []
+                        for chat in batch["chat_list"]:
+                            payload = {
+                                "model": args.openai_model_name,
+                                "messages": chat,
+                                "max_tokens": 2048,
+                                "temperature": args.temperature,
+                                "top_p": args.top_p,
+                                "top_k": args.top_k,
+                                "repetition_penalty": 1.1,
+                                "stop": ["<|eos1|>", "<|eos|>"],
+                                "stream": False,
+                            }
+                            response = requests.post(args.openai_api_base, json=payload)
+                            response.raise_for_status()
+                            response_json = response.json()
+                            generated_content = response_json['choices'][0]['message']['content']
+                            outputs.append(generated_content)
+
+                llm_end_time = time.time()
+                total_llm_time += (llm_end_time - llm_start_time)
+
+                items_for_token_2wav = []
+                for i in range(len(batch["ids"])):
+                    llm_post_processing_start_time = time.time()
+                    if args.backend == "trtllm-serve":
+                        speech_tokens_str = outputs[i].strip().split('><')
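+                        # str.split('><') strips the '<' / '>' from the inner tokens,
+                        # so restore them before parsing out the integer IDs.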
+                        if len(speech_tokens_str) > 1:
+                            speech_tokens_str = [
+                                t if t.startswith('<') else '<' + t for t in speech_tokens_str
+                            ]
+                            speech_tokens_str = [
+                                t if t.endswith('>') else t + '>' for t in speech_tokens_str
+                            ]
+                        speech_ids = extract_speech_ids(speech_tokens_str)
+                    else:
+                        input_length = len(batch["input_ids"][i])
+                        generated_ids = outputs[i][input_length:]
+                        speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+                        speech_ids = extract_speech_ids(speech_tokens_str)
+                    print(i, speech_ids[:10], "...", f"total={len(speech_ids)}")
+                    if len(speech_ids) == 0:
+                        print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
+                        llm_post_processing_end_time = time.time()
+                        total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
+                        continue
+
+                    current_prompt_audio = batch["prompt_audio_list"][i]
+
+                    llm_post_processing_end_time = time.time()
+                    total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
+
+                    items_for_token_2wav.append({
+                        "speech_ids": speech_ids,
+                        "prompt_audio": current_prompt_audio.squeeze(0),
+                        "id": batch["ids"][i]
+                    })
+
+                for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
+                    t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
+                    if not t2w_batch:
+                        continue
+
+                    t2w_speech_tokens = [item["speech_ids"] for item in t2w_batch]
+                    t2w_prompt_audios = [item["prompt_audio"] for item in t2w_batch]
+                    t2w_sample_rates = [16000] * len(t2w_batch)
+
+                    token2wav_start_time = time.time()
+                    generated_wavs = token2wav_model(
+                        t2w_speech_tokens, t2w_prompt_audios, t2w_sample_rates,
+                        streaming=args.streaming,
+                    )
+                    token2wav_end_time = time.time()
+                    total_token2wav_time += (token2wav_end_time - token2wav_start_time)
+
+                    audio_save_start_time = time.time()
+                    for j, audio_hat in enumerate(generated_wavs):
+                        wav = audio_hat.squeeze().cpu().numpy()
+                        total_audio_samples += len(wav)
+                        sf.write(f"{args.output_dir}/{t2w_batch[j]['id']}.wav", wav, 24000)
+                        print(f"Generated audio for sample {t2w_batch[j]['id']} with {len(t2w_speech_tokens[j])} tokens")
+                    audio_save_end_time = time.time()
+                    total_audio_save_time += audio_save_end_time - audio_save_start_time
+
+            progress_bar.update(len(batch["ids"]))
+            last_batch_end_time = time.time()
+
+        progress_bar.close()
+        end_time = time.time()
+        total_audio_duration_seconds = total_audio_samples / 24000
+
+        log_file_path = os.path.join(args.output_dir, "log.txt")
+        with open(log_file_path, 'w') as f:
+            log_data = {
+                "args": vars(args),
+                "data_load_time_seconds": total_data_load_time,
+                "llm_time_seconds": total_llm_time,
+                "llm_post_processing_time_seconds": total_llm_post_processing_time,
+                "token2wav_time_seconds": total_token2wav_time,
+                "audio_save_time_seconds": total_audio_save_time,
+                "total_audio_duration_seconds": total_audio_duration_seconds,
+                "pipeline_time_seconds": end_time - start_time,
+            }
+            print(log_data)
+            f.write(json.dumps(log_data, indent=4))
+        print(f"Metrics logged to {log_file_path}")
+
+
+if __name__ == "__main__":
+    args = get_args()
+    if args.backend == "vllm":
+        from vllm import LLM, SamplingParams
+    elif args.backend == "trtllm":
+        import tensorrt_llm
+        from tensorrt_llm.runtime import ModelRunnerCpp
+    elif args.backend == "hf":
+        from transformers import AutoModelForCausalLM
+    elif args.backend == "trtllm-serve":
+        pass
+    else:
+        raise ValueError(f"Unsupported backend: {args.backend}")
+    main(args)

+ 90 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/audio_tokenizer/1/model.py

@@ -0,0 +1,90 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import json
+import torch
+from torch.utils.dlpack import to_dlpack
+
+import triton_python_backend_utils as pb_utils
+
+import os
+import numpy as np
+import s3tokenizer
+torch.set_num_threads(1)
+# ORIGINAL_VOCAB_SIZE = 151924
+
+
+class TritonPythonModel:
+    """Triton Python model for audio tokenization.
+
+    This model takes reference audio input and extracts semantic tokens
+    using s3tokenizer.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+
+        self.device = torch.device("cuda")
+        model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v3.onnx")
+        self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
+
+    def execute(self, requests):
+        """Execute inference on the batched requests."""
+        mels = []
+
+        # Process each request in batch
+        for req_idx, request in enumerate(requests):
+            # Extract input tensors
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav_len").as_numpy().item()
+
+            wav_array = torch.from_numpy(wav_array).to(self.device)
+            # Prepare inputs
+            wav = wav_array[:, :wav_len].squeeze(0)
+            mel = s3tokenizer.log_mel_spectrogram(wav)
+            mels.append(mel)
+
+        mels, mels_lens = s3tokenizer.padding(mels)
+        codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
+
+        responses = []
+        for i in range(len(requests)):
+            prompt_speech_tokens = codes[i, :codes_lens[i].item()]
+            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
+                "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[prompt_speech_tokens_tensor])
+            responses.append(inference_response)
+
+        return responses

+ 53 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/audio_tokenizer/config.pbtxt

@@ -0,0 +1,53 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "audio_tokenizer"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+parameters [
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+  },
+  {
+    name: "reference_wav_len"
+    data_type: TYPE_INT32
+    dims: [1]
+  }
+]
+output [
+  {
+    name: "prompt_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]

+ 492 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/cosyvoice3/1/model.py

@@ -0,0 +1,492 @@
+import json
+import re
+import time
+import asyncio
+
+import numpy as np
+import torch
+from torch.utils.dlpack import to_dlpack
+import triton_python_backend_utils as pb_utils
+
+import httpx
+import torchaudio
+from functools import partial
+from matcha.utils.audio import mel_spectrogram as matcha_mel_spectrogram
+
+
+torch.set_num_threads(1)
+
+# CosyVoice3 mel params: fmax=None (Nyquist), center=False
+mel_spectrogram = partial(matcha_mel_spectrogram,
+    n_fft=1920, num_mels=80, sampling_rate=24000,
+    hop_size=480, win_size=1920, fmin=0, fmax=None, center=False)
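+# hop_size=480 at 24 kHz gives 50 mel frames per second, i.e. 2 mel frames per
+# 25 Hz speech token (matching token_mel_ratio = 2 set in initialize()).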
+
+
+def parse_speech_token_string(response_text):
+    """Parse speech tokens from string like '<|s_123|><|s_456|>' into list of int IDs."""
+    speech_tokens = response_text.strip().split('><')
+    if len(speech_tokens) > 1:
+        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
+        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
+    speech_ids = []
+    for token_str in speech_tokens:
+        match = re.match(r'<\|s_(\d+)\|>', token_str)
+        if match:
+            speech_ids.append(int(match.group(1)))
+    return speech_ids
+
+
+class TritonPythonModel:
+    """CosyVoice3 BLS orchestrator for Triton Inference Server.
+
+    Orchestrates: audio_tokenizer, speaker_embedding, remote LLM (httpx),
+    token2wav (flow-only), and vocoder (CausalHiFTGenerator).
+    Supports both streaming (decoupled) and offline (non-decoupled) modes.
+    """
+
+    def initialize(self, args):
+        self.logger = pb_utils.Logger
+        self.model_config = json.loads(args['model_config'])
+        parameters = self.model_config['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+
+        self.device = torch.device("cuda")
+        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
+
+        # Streaming config
+        self.token_frame_rate = 25
+        self.flow_pre_lookahead_len = 3
+        self.token_hop_len = 15
+        self.token_mel_ratio = 2
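+        # 25 Hz speech tokens: the first flow chunk covers 15 tokens (~0.6 s) plus a
+        # 3-token lookahead, and each token expands to 2 mel frames at 24 kHz.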
+        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")
+        self.logger.log_info(f"CosyVoice3 BLS initialized, decoupled={self.decoupled}, "
+                             f"chunk_strategy={self.dynamic_chunk_strategy}")
+
+        # HTTP client for remote LLM (trtllm-serve default port: 8000)
+        self.http_client = httpx.AsyncClient()
+        self.api_base = model_params.get("llm_api_base", "http://localhost:8000/v1/chat/completions")
+
+        # Speaker cache to avoid redundant audio_tokenizer/speaker_embedding calls
+        self.speaker_cache = {}
+
+    def _convert_speech_tokens_to_str(self, speech_tokens):
+        """Convert speech token IDs tensor/list to string like '<|s_N|>'."""
+        if isinstance(speech_tokens, torch.Tensor):
+            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
+        return "".join(f"<|s_{int(tid)}|>" for tid in speech_tokens)
+
+    def _extract_speech_feat(self, speech):
+        """Extract mel spectrogram from 24kHz speech for flow prompt."""
+        speech_feat = mel_spectrogram(speech).squeeze(dim=0).transpose(0, 1)
+        speech_feat = speech_feat.unsqueeze(dim=0).to(self.device)
+        return speech_feat
+
+    async def forward_llm_streaming(self, target_text, reference_text, prompt_speech_tokens):
+        """Async generator: stream LLM tokens via httpx SSE."""
+        full_text = f"{reference_text}{target_text}"
+        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
+
+        chat = [
+            {"role": "user", "content": full_text},
+            {"role": "assistant", "content": prompt_speech_tokens_str}
+        ]
+        payload = {
+            "model": "trt_engines_bfloat16",
+            "messages": chat,
+            "max_tokens": 750,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "stop": ["<|eos1|>", "<|eos|>"],
+            "stream": True,
+        }
+
+        buffer = ""
+        async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    line_data = line[len("data: "):].strip()
+                    if line_data == "[DONE]":
+                        break
+                    try:
+                        json_data = json.loads(line_data)
+                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
+                        if content:
+                            buffer += content
+                            while True:
+                                match = re.search(r"<\|s_(\d+)\|>", buffer)
+                                if not match:
+                                    break
+                                token_num = int(match.group(1))
+                                # final_id = token_num + ORIGINAL_VOCAB_SIZE
+                                yield token_num
+                                buffer = buffer[match.end():]
+                    except json.JSONDecodeError:
+                        continue
+
+        # Flush remaining tokens
+        while True:
+            match = re.search(r"<\|s_(\d+)\|>", buffer)
+            if not match:
+                break
+            token_num = int(match.group(1))
+            #final_id = token_num + ORIGINAL_VOCAB_SIZE
+            yield token_num
+            buffer = buffer[match.end():]
+
+    async def forward_llm_offline(self, target_text, reference_text, prompt_speech_tokens):
+        """Non-streaming LLM call, returns all speech token IDs at once."""
+        full_text = f"{reference_text}{target_text}"
+        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
+
+        chat = [
+            {"role": "user", "content": full_text},
+            {"role": "assistant", "content": prompt_speech_tokens_str}
+        ]
+        payload = {
+            "model": "trt_engines_bfloat16",
+            "messages": chat,
+            "max_tokens": 750,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "stop": ["<|eos1|>", "<|eos|>"],
+            "stream": False,
+        }
+        response = await self.http_client.post(self.api_base, json=payload, timeout=None)
+        response.raise_for_status()
+        response_json = response.json()
+        generated_content = response_json['choices'][0]['message']['content']
+        speech_ids = parse_speech_token_string(generated_content)
+        # return [sid + ORIGINAL_VOCAB_SIZE for sid in speech_ids]
+        return speech_ids
+
+    def forward_audio_tokenizer(self, wav, wav_len):
+        """BLS call to audio_tokenizer."""
+        inference_request = pb_utils.InferenceRequest(
+            model_name='audio_tokenizer',
+            requested_output_names=['prompt_speech_tokens'],
+            inputs=[wav, wav_len]
+        )
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(
+            inference_response, 'prompt_speech_tokens')
+        return torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
+
+    def forward_speaker_embedding(self, wav):
+        """BLS call to speaker_embedding."""
+        inference_request = pb_utils.InferenceRequest(
+            model_name='speaker_embedding',
+            requested_output_names=['prompt_spk_embedding'],
+            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
+        )
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(
+            inference_response, 'prompt_spk_embedding')
+        return torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
+
+    async def forward_token2wav(self, target_speech_tokens, prompt_speech_tokens,
+                                prompt_speech_feat, prompt_spk_embedding,
+                                request_id, token_offset=None, finalize=True,
+                                priority=100):
+        """Async BLS call to token2wav (flow-only). Returns mel tensor."""
+        target_tokens_pb = pb_utils.Tensor.from_dlpack(
+            "target_speech_tokens", to_dlpack(target_speech_tokens))
+        prompt_tokens_pb = pb_utils.Tensor.from_dlpack(
+            "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+        prompt_feat_pb = pb_utils.Tensor.from_dlpack(
+            "prompt_speech_feat", to_dlpack(prompt_speech_feat))
+        prompt_emb_pb = pb_utils.Tensor.from_dlpack(
+            "prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
+
+        inputs = [target_tokens_pb, prompt_tokens_pb, prompt_feat_pb, prompt_emb_pb]
+
+        if token_offset is not None:
+            inputs.append(pb_utils.Tensor("token_offset",
+                          np.array([[token_offset]], dtype=np.int32)))
+            inputs.append(pb_utils.Tensor("finalize",
+                          np.array([[finalize]], dtype=np.bool_)))
+
+        inference_request = pb_utils.InferenceRequest(
+            model_name='token2wav',
+            requested_output_names=['mel'],
+            inputs=inputs,
+            request_id=request_id,
+            parameters={"priority": priority},
+        )
+
+        inference_response = await inference_request.async_exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        mel = pb_utils.get_output_tensor_by_name(inference_response, 'mel')
+        return torch.utils.dlpack.from_dlpack(mel.to_dlpack())
+
+    async def forward_vocoder(self, mel, finalize):
+        """Async BLS call to vocoder. Returns speech tensor."""
+        if mel.dim() == 2:
+            mel = mel.unsqueeze(0)  # [80, T] -> [1, 80, T]
+        mel_pb = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel.float()))
+        finalize_pb = pb_utils.Tensor("finalize",
+                      np.array([[finalize]], dtype=np.bool_))
+
+        inference_request = pb_utils.InferenceRequest(
+            model_name='vocoder',
+            requested_output_names=['tts_speech'],
+            inputs=[mel_pb, finalize_pb],
+        )
+
+        inference_response = await inference_request.async_exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        speech = pb_utils.get_output_tensor_by_name(inference_response, 'tts_speech')
+        return torch.utils.dlpack.from_dlpack(speech.to_dlpack()).cpu()
+
+    def _prepare_prompt(self, request):
+        """Extract reference audio, tokenize, compute speaker embedding and mel feat."""
+        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+
+        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text")
+        reference_text = reference_text.as_numpy()[0][0].decode('utf-8') if reference_text is not None else ""
+        if '<|endofprompt|>' not in reference_text:
+            reference_text = 'You are a helpful assistant.<|endofprompt|>' + reference_text
+
+        # Check speaker cache
+        if reference_text in self.speaker_cache:
+            cached = self.speaker_cache[reference_text]
+            return (cached['prompt_speech_tokens_for_llm'], cached['prompt_speech_tokens'],
+                    cached['prompt_speech_feat'], cached['prompt_spk_embedding'], reference_text)
+
+        # Audio tokenizer
+        wav_np = wav.as_numpy()
+        wav_len_val = wav_len.as_numpy()[0][0]
+        prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+        prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)  # [1, T]
+
+        # Speaker embedding
+        wav_tensor = torch.from_numpy(wav_np)
+        wav_tensor = wav_tensor[:, :wav_len_val]
+        prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+
+        # Mel extraction at 24kHz with CosyVoice3 params
+        prompt_speech_resample = torchaudio.transforms.Resample(
+            orig_freq=16000, new_freq=24000)(wav_tensor)
+        speech_feat = self._extract_speech_feat(prompt_speech_resample)
+
+        # Keep full tokens for LLM prefill (untruncated)
+        prompt_speech_tokens_for_llm = prompt_speech_tokens.clone()
+
+        # Align prompt speech feat and tokens to 2:1 ratio (for flow model only)
+        token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+        prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+        prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+        # Cache
+        self.speaker_cache[reference_text] = {
+            'prompt_speech_tokens_for_llm': prompt_speech_tokens_for_llm,
+            'prompt_speech_tokens': prompt_speech_tokens,
+            'prompt_speech_feat': prompt_speech_feat,
+            'prompt_spk_embedding': prompt_spk_embedding,
+        }
+
+        return prompt_speech_tokens_for_llm, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, reference_text
+
+    async def _process_request_streaming(self, request):
+        """Process a single request in streaming (decoupled) mode."""
+        request_id = request.request_id()
+        response_sender = request.get_response_sender()
+
+        try:
+            prompt_speech_tokens_for_llm, prompt_speech_tokens, prompt_speech_feat, \
+                prompt_spk_embedding, reference_text = self._prepare_prompt(request)
+
+            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+            target_text = target_text[0][0].decode('utf-8')
+
+            semantic_token_ids_arr = []
+            token_offset = 0
+            chunk_index = 0
+            this_token_hop_len = self.token_hop_len
+            accumulated_mel = None
+            speech_offset = 0
+            start_time = time.time()
+
+            async for generated_id in self.forward_llm_streaming(
+                target_text=target_text,
+                reference_text=reference_text,
+                prompt_speech_tokens=prompt_speech_tokens_for_llm,
+            ):
+                semantic_token_ids_arr.append(generated_id)
+
+                while True:
+                    pending_num = len(semantic_token_ids_arr) - token_offset
+                    if pending_num < this_token_hop_len + self.flow_pre_lookahead_len:
+                        break
+
+                    # Prepare tokens for this chunk
+                    end_idx = token_offset + this_token_hop_len + self.flow_pre_lookahead_len
+                    this_tokens = torch.tensor(
+                        semantic_token_ids_arr[:end_idx]
+                    ).unsqueeze(0).to(torch.int32).to(self.device)
+
+                    # Call token2wav (flow-only) -> mel_chunk
+                    mel_chunk = await self.forward_token2wav(
+                        this_tokens, prompt_speech_tokens,
+                        prompt_speech_feat, prompt_spk_embedding,
+                        request_id, token_offset=token_offset, finalize=False,
+                        priority=chunk_index + 1,
+                    )
+
+                    # Accumulate mel
+                    if mel_chunk.dim() == 2:
+                        mel_chunk = mel_chunk.unsqueeze(0)
+                    if accumulated_mel is None:
+                        accumulated_mel = mel_chunk
+                    else:
+                        accumulated_mel = torch.cat([accumulated_mel, mel_chunk], dim=2)
+
+                    # Call vocoder
+                    speech = await self.forward_vocoder(accumulated_mel, finalize=False)
+
+                    # Extract new speech
+                    new_speech = speech[:, speech_offset:]
+                    speech_offset += new_speech.shape[1]
+
+                    if new_speech.shape[1] > 0:
+                        audio_tensor = pb_utils.Tensor.from_dlpack(
+                            "waveform", to_dlpack(new_speech))
+                        inference_response = pb_utils.InferenceResponse(
+                            output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                    token_offset += this_token_hop_len
+
+                    # Dynamic chunk strategy
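+                    # "exponential" grows the hop geometrically (15, 25, 50, 100, ... tokens)
+                    # to keep first-chunk latency low; "time_based" only enlarges the hop
+                    # when synthesis is running ahead of real time.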
+                    if self.dynamic_chunk_strategy == "exponential":
+                        this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                    elif self.dynamic_chunk_strategy == "time_based":
+                        cost_time = time.time() - start_time
+                        duration = token_offset / self.token_frame_rate
+                        if chunk_index > 0 and cost_time > 0:
+                            avg_chunk_time = cost_time / (chunk_index + 1)
+                            if avg_chunk_time > 0:
+                                multiples = (duration - cost_time) / avg_chunk_time
+                                next_pending = len(semantic_token_ids_arr) - token_offset
+                                if multiples > 4:
+                                    this_token_hop_len = (next_pending // self.token_hop_len + 1) * self.token_hop_len
+                                elif multiples > 2:
+                                    this_token_hop_len = (next_pending // self.token_hop_len) * self.token_hop_len
+                                else:
+                                    this_token_hop_len = self.token_hop_len
+                                this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
+
+                    chunk_index += 1
+
+            # Final chunk with remaining tokens
+            if len(semantic_token_ids_arr) > 0:
+                remaining_tokens = torch.tensor(
+                    semantic_token_ids_arr
+                ).unsqueeze(0).to(torch.int32).to(self.device)
+
+                mel_chunk = await self.forward_token2wav(
+                    remaining_tokens, prompt_speech_tokens,
+                    prompt_speech_feat, prompt_spk_embedding,
+                    request_id, token_offset=token_offset, finalize=True,
+                    priority=chunk_index + 1,
+                )
+
+                if mel_chunk.dim() == 2:
+                    mel_chunk = mel_chunk.unsqueeze(0)
+                if accumulated_mel is None:
+                    accumulated_mel = mel_chunk
+                else:
+                    accumulated_mel = torch.cat([accumulated_mel, mel_chunk], dim=2)
+
+                speech = await self.forward_vocoder(accumulated_mel, finalize=True)
+
+                new_speech = speech[:, speech_offset:]
+                if new_speech.shape[1] > 0:
+                    audio_tensor = pb_utils.Tensor.from_dlpack(
+                        "waveform", to_dlpack(new_speech))
+                    inference_response = pb_utils.InferenceResponse(
+                        output_tensors=[audio_tensor])
+                    response_sender.send(inference_response)
+
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+        except Exception as e:
+            self.logger.log_error(f"Error in streaming request: {e}")
+            error_response = pb_utils.InferenceResponse(
+                error=pb_utils.TritonError(str(e)))
+            response_sender.send(error_response)
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+
+    async def _process_request_offline(self, request):
+        """Process a single request in offline (non-decoupled) mode."""
+        request_id = request.request_id()
+
+        prompt_speech_tokens_for_llm, prompt_speech_tokens, prompt_speech_feat, \
+            prompt_spk_embedding, reference_text = self._prepare_prompt(request)
+
+        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+        target_text = target_text[0][0].decode('utf-8')
+
+        # Get all speech tokens at once (use full untruncated prompt tokens for LLM)
+        all_token_ids = await self.forward_llm_offline(
+            target_text=target_text,
+            reference_text=reference_text,
+            prompt_speech_tokens=prompt_speech_tokens_for_llm,
+        )
+
+        if len(all_token_ids) == 0:
+            raise pb_utils.TritonModelException("LLM generated no speech tokens")
+
+        all_tokens = torch.tensor(all_token_ids).unsqueeze(0).to(torch.int32).to(self.device)
+
+        # token2wav (no token_offset, finalize=True) -> full mel
+        mel = await self.forward_token2wav(
+            all_tokens, prompt_speech_tokens,
+            prompt_speech_feat, prompt_spk_embedding,
+            request_id,
+        )
+
+        # vocoder -> full speech
+        speech = await self.forward_vocoder(mel, finalize=True)
+
+        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(speech))
+        return pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+
+    async def execute(self, requests):
+        if self.decoupled:
+            tasks = [
+                asyncio.create_task(self._process_request_streaming(request))
+                for request in requests
+            ]
+            await asyncio.gather(*tasks)
+            return None
+        else:
+            responses = []
+            for request in requests:
+                try:
+                    response = await self._process_request_offline(request)
+                    responses.append(response)
+                except Exception as e:
+                    self.logger.log_error(f"Error in offline request: {e}")
+                    responses.append(pb_utils.InferenceResponse(
+                        error=pb_utils.TritonError(str(e))))
+            return responses
+
+    def finalize(self):
+        self.logger.log_info("Finalizing CosyVoice3 BLS model")
+        if hasattr(self, "http_client"):
+            asyncio.run(self.http_client.aclose())

+ 73 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/cosyvoice3/config.pbtxt

@@ -0,0 +1,73 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "cosyvoice3"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+model_transaction_policy {
+  decoupled: ${decoupled_mode}
+}
+parameters [
+  {
+   key: "llm_tokenizer_dir",
+   value: {string_value:"${llm_tokenizer_dir}"}
+  },
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+    optional: true
+  },
+  {
+    name: "reference_wav_len"
+    data_type: TYPE_INT32
+    dims: [1]
+    optional: true
+  },
+  {
+    name: "reference_text"
+    data_type: TYPE_STRING
+    dims: [1]
+    optional: true
+  },
+  {
+    name: "target_text"
+    data_type: TYPE_STRING
+    dims: [1]
+  }
+]
+output [
+  {
+    name: "waveform"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: ${bls_instance_num}
+    kind: KIND_CPU
+  }
+]

+ 146 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/speaker_embedding/1/model.py

@@ -0,0 +1,146 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import json
+import torch
+from torch.utils.dlpack import to_dlpack
+
+import triton_python_backend_utils as pb_utils
+
+import os
+import numpy as np
+import torchaudio.compliance.kaldi as kaldi
+from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.utils.common import TrtContextWrapper
+import onnxruntime
+
+
+class TritonPythonModel:
+    """Triton Python model for audio tokenization.
+
+    This model takes reference audio input and extracts semantic tokens
+    using s3tokenizer.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+
+        self.device = torch.device("cuda")
+
+        model_dir = model_params["model_dir"]
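+        # NOTE: the GPU tag and TRT toggle are hardcoded here; if the serialized
+        # campplus.<gpu>.fp32.trt engine is missing, it is built from campplus.onnx.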
+        gpu = "l20"
+        enable_trt = True
+        if enable_trt:
+            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
+                              f'{model_dir}/campplus.onnx',
+                              1,
+                              False)
+        else:
+            campplus_model = f'{model_dir}/campplus.onnx'
+            option = onnxruntime.SessionOptions()
+            option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+            option.intra_op_num_threads = 1
+            self.spk_model = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+
+    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
+        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
+            trt_kwargs = self.get_spk_trt_kwargs()
+            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
+        import tensorrt as trt
+        with open(spk_model, 'rb') as f:
+            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
+        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_spk_trt_kwargs(self):
+        min_shape = [(1, 4, 80)]
+        opt_shape = [(1, 500, 80)]
+        max_shape = [(1, 3000, 80)]
+        input_names = ["input"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def _extract_spk_embedding(self, speech):
+        feat = kaldi.fbank(speech,
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
+        spk_feat = feat - feat.mean(dim=0, keepdim=True)
+
+        if isinstance(self.spk_model, onnxruntime.InferenceSession):
+            embedding = self.spk_model.run(
+                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+            )[0].flatten().tolist()
+            embedding = torch.tensor([embedding]).to(self.device)
+        else:
+            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
+            # NOTE need to synchronize when switching stream
+            with torch.cuda.device(self.device):
+                torch.cuda.current_stream().synchronize()
+                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
+                batch_size = spk_feat.size(0)
+
+                with stream:
+                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
+                    embedding = torch.empty((batch_size, 192), device=spk_feat.device)
+
+                    data_ptrs = [spk_feat.contiguous().data_ptr(),
+                                 embedding.contiguous().data_ptr()]
+                    for i, j in enumerate(data_ptrs):
+                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
+                    # run trt engine
+                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                    torch.cuda.current_stream().synchronize()
+                self.spk_model.release_estimator(spk_model, stream)
+
+        return embedding.half()
+
+    def execute(self, requests):
+        """Execute inference on the batched requests."""
+        responses = []
+        # Process each request in batch
+        for req_idx, request in enumerate(requests):
+            # Extract input tensors
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
+            wav_array = torch.from_numpy(wav_array).to(self.device)
+
+            embedding = self._extract_spk_embedding(wav_array)
+
+            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
+                "prompt_spk_embedding", to_dlpack(embedding))
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[prompt_spk_embedding_tensor])
+
+            responses.append(inference_response)
+
+        return responses

+ 48 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/speaker_embedding/config.pbtxt

@@ -0,0 +1,48 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "speaker_embedding"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+parameters [
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+  }
+]
+output [
+  {
+    name: "prompt_spk_embedding"
+    data_type: TYPE_FP16
+    dims: [-1]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]
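For reference, a minimal client sketch for this module (not part of the commit): it assumes a Triton server started as in run_cosyvoice3.sh (gRPC on port 18001) and a 16 kHz mono prompt_audio.wav, and uses only the tensor names declared in the config above.

    import soundfile as sf
    import tritonclient.grpc as grpcclient

    wav, _ = sf.read("prompt_audio.wav", dtype="float32")      # assumed 16 kHz mono
    samples = wav.reshape(1, -1)                                # leading batch dim (max_batch_size > 0)

    client = grpcclient.InferenceServerClient("localhost:18001")
    inp = grpcclient.InferInput("reference_wav", list(samples.shape), "FP32")
    inp.set_data_from_numpy(samples)
    out = grpcclient.InferRequestedOutput("prompt_spk_embedding")
    result = client.infer("speaker_embedding", inputs=[inp], outputs=[out])
    embedding = result.as_numpy("prompt_spk_embedding")        # float16 CAM++ embedding (192-dim)
    print(embedding.shape)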

+ 200 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/token2wav/1/model.py

@@ -0,0 +1,200 @@
+import json
+import os
+import logging
+import queue
+
+import torch
+import numpy as np
+from torch.utils.dlpack import to_dlpack
+import triton_python_backend_utils as pb_utils
+from hyperpyyaml import load_hyperpyyaml
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+class TrtContextWrapper:
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        self.trt_engine = trt_engine
+        self.device = device
+        for _ in range(trt_concurrent):
+            trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
+            assert trt_context is not None
+            self.trt_context_pool.put([trt_context, trt_stream])
+
+    def acquire_estimator(self):
+        return self.trt_context_pool.get(), self.trt_engine
+
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])
+
+
+def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16, autocast_mode=False):
+    import tensorrt as trt
+    logging.info("Converting onnx to trt...")
+    if autocast_mode:
+        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
+    else:
+        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    trt_logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(trt_logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, trt_logger)
+    config = builder.create_builder_config()
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)
+    if not autocast_mode and fp16:
+        config.set_flag(trt.BuilderFlag.FP16)
+    profile = builder.create_optimization_profile()
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError(f'failed to parse {onnx_model}')
+    for i in range(len(trt_kwargs['input_names'])):
+        profile.set_shape(trt_kwargs['input_names'][i],
+                          trt_kwargs['min_shape'][i],
+                          trt_kwargs['opt_shape'][i],
+                          trt_kwargs['max_shape'][i])
+    if not autocast_mode:
+        tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+        for i in range(network.num_inputs):
+            network.get_input(i).dtype = tensor_dtype
+        for i in range(network.num_outputs):
+            network.get_output(i).dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
+    logging.info("Successfully converted onnx to trt")
+
+torch.set_num_threads(1)
+
+
+class TritonPythonModel:
+    """Triton Python model for CosyVoice3 token2wav (flow-only, stateless).
+
+    Converts speech tokens to mel spectrogram using the CausalMaskedDiffWithDiT flow model.
+    """
+
+    def initialize(self, args):
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+        model_dir = model_params["model_dir"]
+
+        self.device = torch.device("cuda")
+
+        # Load flow model from cosyvoice3.yaml
+        with open(os.path.join(model_dir, 'cosyvoice3.yaml'), 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={
+                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
+            })
+        self.flow = configs['flow']
+        self.fp16 = True
+        self.flow.half()
+        self.flow.load_state_dict(
+            torch.load(os.path.join(model_dir, 'flow.pt'),
+                        map_location='cpu', weights_only=True),
+            strict=True
+        )
+        self.flow.to(self.device).eval()
+
+        # TRT acceleration for flow decoder estimator
+        self.load_trt(model_dir)
+
+        self.token_mel_ratio = self.flow.token_mel_ratio
+        logger.info(f"Token2wav (flow-only) initialized, token_mel_ratio={self.token_mel_ratio}")
+
+    def load_trt(self, model_dir, trt_concurrent=1):
+        device_id = torch.cuda.current_device()
+        onnx_path = os.path.join(model_dir, 'flow.decoder.estimator.autocast_fp16.onnx')
+        trt_path = os.path.join(model_dir, f'flow.decoder.estimator.autocast_fp16.{device_id}.plan')
+
+        if not os.path.exists(trt_path) or os.path.getsize(trt_path) == 0:
+            trt_kwargs = self.get_trt_kwargs()
+            convert_onnx_to_trt(trt_path, trt_kwargs, onnx_path,
+                                fp16=True, autocast_mode=True)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(trt_path, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, f'failed to load trt {trt_path}'
+        self.flow.decoder.estimator = TrtContextWrapper(
+            estimator_engine, trt_concurrent=trt_concurrent, device=str(self.device))
+
+    def get_trt_kwargs(self):
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
+        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
+        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
+        input_names = ["x", "mask", "mu", "cond"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape,
+                'max_shape': max_shape, 'input_names': input_names}
+
+    def execute(self, requests):
+        responses = []
+        for req_idx, request in enumerate(requests):
+            target_speech_tokens = pb_utils.get_input_tensor_by_name(
+                request, "target_speech_tokens")
+            target_speech_tokens = torch.utils.dlpack.from_dlpack(
+                target_speech_tokens.to_dlpack()).to(self.device)
+            if target_speech_tokens.dim() == 1:
+                target_speech_tokens = target_speech_tokens.unsqueeze(0)
+
+            # Optional inputs
+            prompt_speech_tokens_pb = pb_utils.get_input_tensor_by_name(
+                request, "prompt_speech_tokens")
+            if prompt_speech_tokens_pb is not None:
+                prompt_speech_tokens = torch.utils.dlpack.from_dlpack(
+                    prompt_speech_tokens_pb.to_dlpack()).to(self.device)
+                if prompt_speech_tokens.dim() == 1:
+                    prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+                prompt_speech_feat = pb_utils.get_input_tensor_by_name(
+                    request, "prompt_speech_feat")
+                prompt_speech_feat = torch.utils.dlpack.from_dlpack(
+                    prompt_speech_feat.to_dlpack()).to(self.device)
+                if prompt_speech_feat.dim() == 2:
+                    prompt_speech_feat = prompt_speech_feat.unsqueeze(0)  # [T, 80] -> [1, T, 80]
+
+                prompt_spk_embedding = pb_utils.get_input_tensor_by_name(
+                    request, "prompt_spk_embedding")
+                prompt_spk_embedding = torch.utils.dlpack.from_dlpack(
+                    prompt_spk_embedding.to_dlpack()).to(self.device)
+                if prompt_spk_embedding.dim() == 1:
+                    prompt_spk_embedding = prompt_spk_embedding.unsqueeze(0)
+            else:
+                raise ValueError("prompt_speech_tokens is required for CosyVoice3 token2wav")
+
+            token_offset_pb = pb_utils.get_input_tensor_by_name(request, "token_offset")
+            finalize_pb = pb_utils.get_input_tensor_by_name(request, "finalize")
+
+            token_offset = token_offset_pb.as_numpy().item() if token_offset_pb is not None else None
+            finalize = finalize_pb.as_numpy().item() if finalize_pb is not None else True
+            streaming = not finalize
+
+            with torch.no_grad(), torch.cuda.amp.autocast(self.fp16):
+                mel, _ = self.flow.inference(
+                    token=target_speech_tokens,
+                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
+                    prompt_token=prompt_speech_tokens,
+                    prompt_token_len=torch.tensor([prompt_speech_tokens.shape[1]], dtype=torch.int32).to(self.device),
+                    prompt_feat=prompt_speech_feat,
+                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
+                    embedding=prompt_spk_embedding,
+                    streaming=streaming,
+                    finalize=finalize,
+                )
+
+            # Slice mel from token_offset if provided
+            if token_offset is not None:
+                mel = mel[:, :, token_offset * self.token_mel_ratio:]
+
+            # Output mel as [80, T] (squeeze batch dim for Triton)
+            mel_out = mel.squeeze(0).float()  # [80, T]
+            mel_out = mel_out.cpu()  # work around a DLPack issue when exporting the CUDA tensor
+            mel_tensor = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel_out))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[mel_tensor])
+            responses.append(inference_response)
+
+        return responses

+ 71 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/token2wav/config.pbtxt

@@ -0,0 +1,71 @@
+name: "token2wav"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+    priority_levels: 100
+    default_priority_level: 100
+}
+
+parameters [
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "target_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+  },
+  {
+    name: "prompt_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+    optional: true
+  },
+  {
+    name: "prompt_speech_feat"
+    data_type: TYPE_FP16
+    dims: [-1, 80]
+    optional: true
+  },
+  {
+    name: "prompt_spk_embedding"
+    data_type: TYPE_FP16
+    dims: [-1]
+    optional: true
+  },
+  {
+    name: "token_offset"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "finalize"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  }
+]
+output [
+  {
+    name: "mel"
+    data_type: TYPE_FP32
+    dims: [ 80, -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_GPU
+    gpus: [ 0 ]
+  }
+]
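A minimal sketch (not part of the commit) of how a BLS module such as cosyvoice3 might call this model through pb_utils; the helper name and array shapes are illustrative, only the tensor names and dtypes come from the config above. In streaming mode, token_offset trims mel frames already produced in earlier chunks: with token_mel_ratio = 2 and token_offset = 25, the first 50 frames are dropped.

    import numpy as np
    import triton_python_backend_utils as pb_utils

    def call_token2wav(target_tokens, prompt_tokens, prompt_feat, spk_emb,
                       token_offset=0, finalize=True):
        # every array carries a leading batch dim of 1; dtypes follow config.pbtxt
        inputs = [
            pb_utils.Tensor("target_speech_tokens", target_tokens.astype(np.int32)),
            pb_utils.Tensor("prompt_speech_tokens", prompt_tokens.astype(np.int32)),
            pb_utils.Tensor("prompt_speech_feat", prompt_feat.astype(np.float16)),
            pb_utils.Tensor("prompt_spk_embedding", spk_emb.astype(np.float16)),
            pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32)),
            pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)),
        ]
        request = pb_utils.InferenceRequest(
            model_name="token2wav",
            requested_output_names=["mel"],
            inputs=inputs)
        response = request.exec()
        if response.has_error():
            raise RuntimeError(response.error().message())
        return pb_utils.get_output_tensor_by_name(response, "mel").as_numpy()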

+ 69 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/vocoder/1/model.py

@@ -0,0 +1,69 @@
+import json
+import os
+import logging
+
+import torch
+from torch.utils.dlpack import to_dlpack
+import triton_python_backend_utils as pb_utils
+from hyperpyyaml import load_hyperpyyaml
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+torch.set_num_threads(1)
+
+
+class TritonPythonModel:
+    """Triton Python model for CosyVoice3 vocoder (CausalHiFTGenerator).
+
+    Stateless: converts mel spectrogram to waveform.
+    CausalHiFTGenerator manages its own internal cache.
+    """
+
+    def initialize(self, args):
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+        model_dir = model_params["model_dir"]
+
+        self.device = torch.device("cuda")
+
+        # Load CausalHiFTGenerator from cosyvoice3.yaml
+        with open(os.path.join(model_dir, 'cosyvoice3.yaml'), 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={
+                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
+            })
+        self.hift = configs['hift']
+        hift_state_dict = {
+            k.replace('generator.', ''): v
+            for k, v in torch.load(
+                os.path.join(model_dir, 'hift.pt'),
+                map_location='cpu', weights_only=True
+            ).items()
+        }
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+        logger.info("CausalHiFTGenerator initialized successfully")
+
+    def execute(self, requests):
+        responses = []
+        for req_idx, request in enumerate(requests):
+            mel = pb_utils.get_input_tensor_by_name(request, "mel")
+            mel = torch.utils.dlpack.from_dlpack(mel.to_dlpack()).to(self.device)
+            if mel.dim() == 2:
+                mel = mel.unsqueeze(0)  # [80, T] -> [1, 80, T]
+
+            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+
+            with torch.no_grad():
+                speech, _ = self.hift.inference(speech_feat=mel, finalize=finalize)
+
+            # speech shape: [1, 1, S] or [1, S] depending on hift version
+            speech = speech.squeeze()  # flatten to [S]
+
+            speech_tensor = pb_utils.Tensor.from_dlpack(
+                "tts_speech", to_dlpack(speech.unsqueeze(0)))  # [1, S] for batch dim
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[speech_tensor])
+            responses.append(inference_response)
+
+        return responses

+ 40 - 0
runtime/triton_trtllm/model_repo_cosyvoice3/vocoder/config.pbtxt

@@ -0,0 +1,40 @@
+name: "vocoder"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+parameters [
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "mel"
+    data_type: TYPE_FP32
+    dims: [80, -1]
+  },
+  {
+    name: "finalize"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+  }
+]
+output [
+  {
+    name: "tts_speech"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]
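And a matching sketch for the vocoder step (same caveats as the token2wav sketch above); mel is the array returned by token2wav:

    import numpy as np
    import triton_python_backend_utils as pb_utils

    def call_vocoder(mel, finalize=True):
        mel = mel.astype(np.float32)
        if mel.ndim == 2:
            mel = mel[None, ...]  # add the leading batch dim expected when max_batch_size > 0
        inputs = [
            pb_utils.Tensor("mel", mel),
            pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_)),
        ]
        request = pb_utils.InferenceRequest(
            model_name="vocoder",
            requested_output_names=["tts_speech"],
            inputs=inputs)
        response = request.exec()
        if response.has_error():
            raise RuntimeError(response.error().message())
        return pb_utils.get_output_tensor_by_name(response, "tts_speech").as_numpy()  # waveform samples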

+ 252 - 0
runtime/triton_trtllm/run_cosyvoice3.sh

@@ -0,0 +1,252 @@
+#!/bin/bash
+# Copyright (c) 2026 NVIDIA (authors: Yuekai Zhang)
+export CUDA_VISIBLE_DEVICES=0
+# cosyvoice_path=/workspace/CosyVoice
+cosyvoice_path=/workspace_yuekai/tts/CosyVoice
+
+export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
+export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
+
+stage=$1
+stop_stage=$2
+
+huggingface_model_local_dir=./hf_cosyvoice3_llm
+model_scope_model_local_dir=/workspace_yuekai/HF/Fun-CosyVoice3-0.5B-2512
+
+trt_dtype=bfloat16
+trt_weights_dir=./trt_weights_${trt_dtype}
+trt_engines_dir=./trt_engines_${trt_dtype}
+
+model_repo_src=./model_repo_cosyvoice3
+model_repo=./deploy_cosyvoice3
+bls_instance_num=1
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+
+    echo "Cloning CosyVoice"
+    git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
+    cd $cosyvoice_path
+    git submodule update --init --recursive
+    cd runtime/triton_trtllm
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+    echo ""
+    # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
+    # huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
+    # modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
+
+    # pip3 install --upgrade x_transformers s3tokenizer 
+    # pip install -U nvidia-modelopt[all]
+    python3 scripts/convert_cosyvoice3_to_hf.py \
+        --model-dir $model_scope_model_local_dir \
+        --output-dir $huggingface_model_local_dir || exit 1 # TODO: output dir should be here
+
+fi
+
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+    echo "Converting checkpoint to TensorRT weights"
+    python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
+                                --output_dir $trt_weights_dir \
+                                --dtype $trt_dtype || exit 1
+
+    echo "Building TensorRT engines"
+    trtllm-build --checkpoint_dir $trt_weights_dir \
+                --output_dir $trt_engines_dir \
+                --max_batch_size 64 \
+                --max_num_tokens 32768 \
+                --gemm_plugin $trt_dtype || exit 1
+
+    echo "Testing TensorRT engines"
+    python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
+                    --tokenizer_dir $huggingface_model_local_dir \
+                    --top_k 50 --top_p 0.95 --temperature 0.8 \
+                    --engine_dir=$trt_engines_dir  || exit 1
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+    echo "Creating CosyVoice3 model repository"
+    rm -rf $model_repo
+    mkdir -p $model_repo
+
+    # Copy all modules from template source
+    cp -r ${model_repo_src}/cosyvoice3 $model_repo/
+    cp -r ${model_repo_src}/token2wav $model_repo/
+    cp -r ${model_repo_src}/vocoder $model_repo/
+    cp -r ${model_repo_src}/audio_tokenizer $model_repo/
+    cp -r ${model_repo_src}/speaker_embedding $model_repo/
+
+    MAX_QUEUE_DELAY_MICROSECONDS=0
+    MODEL_DIR=$model_scope_model_local_dir
+    LLM_TOKENIZER_DIR=$huggingface_model_local_dir
+    BLS_INSTANCE_NUM=$bls_instance_num
+    TRITON_MAX_BATCH_SIZE=1
+    DECOUPLED_MODE=True
+
+    python3 scripts/fill_template.py -i ${model_repo}/cosyvoice3/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+
+fi
+
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+    echo "Starting CosyVoice3 Triton server and LLM using trtllm-serve"
+    CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64  --kv_cache_free_gpu_memory_fraction 0.4
+fi
+
+
+if [ $stage -le 40 ] && [ $stop_stage -ge 40 ]; then
+
+   CUDA_VISIBLE_DEVICES=1 tritonserver --model-repository $model_repo --http-port 18000 --grpc-port 18001 --metrics-port 18002 &
+fi
+
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+   echo "Starting CosyVoice3 Triton server and LLM using trtllm-serve"
+   CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64  --kv_cache_free_gpu_memory_fraction 0.4 &
+   CUDA_VISIBLE_DEVICES=0,1,2,3 tritonserver --model-repository $model_repo --http-port 18000 --grpc-port 18001 --metrics-port 18002 &
+   wait
+    # Test using curl
+    # curl http://localhost:8000/v1/chat/completions \
+    #     -H "Content-Type: application/json" \
+    #     -d '{
+    #         "model": "",
+    #         "messages":[{"role": "user", "content": "Where is New York?"},
+    #                     {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
+    #         "max_tokens": 512,
+    #         "temperature": 0.8,
+    #         "top_p": 0.95,
+    #         "top_k": 50,
+    #         "stop": ["<|eos1|>"],
+    #         "repetition_penalty": 1.2,
+    #         "stream": false
+    #     }'
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+    echo "Running benchmark client for CosyVoice3"
+    num_task=4
+    mode=offline
+    mode=streaming
+    BLS_INSTANCE_NUM=$bls_instance_num
+
+    python3 client_grpc.py \
+        --server-addr localhost \
+        --server-port 18001 \
+        --model-name cosyvoice3 \
+        --num-tasks $num_task \
+        --mode $mode \
+        --huggingface-dataset yuekai/seed_tts_cosy2 \
+        --log-dir ./log_cosyvoice3_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
+
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
+
+  datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
+  backend=trtllm # hf, trtllm, vllm, trtllm-serve
+
+  batch_sizes=(16)
+  token2wav_batch_size=1
+
+  for batch_size in ${batch_sizes[@]}; do
+    for dataset in ${datasets[@]}; do
+    output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
+    CUDA_VISIBLE_DEVICES=1 \
+        python3 offline_inference.py \
+            --output-dir $output_dir \
+            --llm-model-name-or-path $huggingface_model_local_dir \
+            --token2wav-path $step_audio_model_dir/token2wav \
+            --backend $backend \
+            --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+            --engine-dir $trt_engines_dir \
+            --split-name ${dataset} || exit 1
+    done
+  done
+fi
+
+
+
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+   echo "Disaggregated Server: LLM and Token2wav on different GPUs"
+   echo "Starting LLM server on GPU 0"
+   export CUDA_VISIBLE_DEVICES=0
+   mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64  --kv_cache_free_gpu_memory_fraction 0.4 &
+   echo "Starting Token2wav server on GPUs 1-3"
+   Token2wav_num_gpus=3
+   http_port=17000
+   grpc_port=18000
+   metrics_port=16000
+   for i in $(seq 0 $(($Token2wav_num_gpus - 1))); do
+       echo "Starting server on GPU $i"
+       http_port=$((http_port + 1))
+       grpc_port=$((grpc_port + 1))
+       metrics_port=$((metrics_port + 1))
+       # Two instances of Token2wav server on the same GPU
+       CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
+       http_port=$((http_port + 1))
+       grpc_port=$((grpc_port + 1))
+       metrics_port=$((metrics_port + 1))
+       CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
+   done
+   wait
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+    echo "Running benchmark client for Disaggregated Server"
+    per_gpu_instances=2
+    mode=streaming
+    BLS_INSTANCE_NUM=$bls_instance_num
+    Token2wav_num_gpus=(1 2 3)
+    concurrent_tasks=(1 2 3 4 5 6)
+    for n_gpu in ${Token2wav_num_gpus[@]}; do
+        echo "Test 1 GPU for LLM server and $n_gpu GPUs for Token2wav servers"
+        for concurrent_task in ${concurrent_tasks[@]}; do
+            num_instances=$((per_gpu_instances * n_gpu))
+            for i in $(seq 1 $num_instances); do
+                port=$(($i + 18000))
+                python3 client_grpc.py \
+                    --server-addr localhost \
+                    --server-port $port \
+                    --model-name cosyvoice2_dit \
+                    --num-tasks $concurrent_task \
+                    --mode $mode \
+                    --huggingface-dataset yuekai/seed_tts_cosy2 \
+                    --log-dir ./log_disagg_concurrent_tasks_${concurrent_task}_per_instance_total_token2wav_instances_${num_instances}_port_${port} &
+            done
+            wait
+        done
+    done
+fi
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+    echo "stage 10: Python script CosyVoice3 TTS (LLM + CosyVoice3 Token2Wav) inference"
+
+    datasets=(wenetspeech4tts) # wenetspeech4tts
+    backend=trtllm-serve  # hf, trtllm, vllm, trtllm-serve
+
+    batch_sizes=(1)
+    token2wav_batch_size=1
+
+    for batch_size in ${batch_sizes[@]}; do
+      for dataset in ${datasets[@]}; do
+        output_dir=./cosyvoice3_${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}_streaming_trt
+        CUDA_VISIBLE_DEVICES=0 \
+            python3 infer_cosyvoice3.py \
+                --output-dir $output_dir \
+                --llm-model-name-or-path $huggingface_model_local_dir \
+                --token2wav-path $model_scope_model_local_dir \
+                --backend $backend \
+                --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+                --engine-dir $trt_engines_dir \
+                --enable-trt --streaming \
+                --epoch 1 \
+                --split-name ${dataset} || exit 1
+      done
+    done
+fi

+ 373 - 0
runtime/triton_trtllm/scripts/convert_cosyvoice3_to_hf.py

@@ -0,0 +1,373 @@
+#!/usr/bin/env python3
+# Copyright 2025 CosyVoice3 TRT-LLM Integration
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Конвертация CosyVoice3 LLM в HuggingFace формат с объединёнными embeddings.
+
+Этот скрипт:
+1. Загружает CosyVoice3 модель
+2. Расширяет vocab токенизатора с speech токенами
+3. Объединяет speech_embedding в embed_tokens Qwen2
+4. Заменяет lm_head на llm_decoder с расширенным vocab
+5. Сохраняет модель в HuggingFace формате для TRT-LLM конвертации
+
+Usage:
+    python scripts/convert_cosyvoice3_to_hf.py \
+        --model-dir pretrained_models/Fun-CosyVoice3-0.5B \
+        --output-dir pretrained_models/Fun-CosyVoice3-0.5B/hf_merged
+
+После этого можно конвертировать в TRT-LLM:
+    trtllm-build --checkpoint_dir <output_dir> --output_dir <trt_engines_dir> ...
+"""
+import argparse
+import os
+import sys
+import logging
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'third_party/Matcha-TTS'))
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Convert CosyVoice3 to HuggingFace format with merged embeddings")
+    parser.add_argument(
+        "--model-dir",
+        type=str,
+        default="pretrained_models/Fun-CosyVoice3-0.5B",
+        help="Path to CosyVoice3 model directory",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default=None,
+        help="Output directory for HuggingFace model (default: <model-dir>/hf_merged)",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="bfloat16",
+        choices=["float16", "bfloat16", "float32"],
+        help="Output dtype for the model",
+    )
+    return parser.parse_args()
+
+
+def load_cosyvoice3_model(model_dir: str):
+    """Загружает CosyVoice3 модель для извлечения весов."""
+    from hyperpyyaml import load_hyperpyyaml
+    from cosyvoice.utils.class_utils import get_model_type
+    
+    hyper_yaml_path = os.path.join(model_dir, 'cosyvoice3.yaml')
+    hf_llm_dir = os.path.join(model_dir, 'CosyVoice-BlankEN')
+    
+    if not os.path.exists(hyper_yaml_path):
+        raise ValueError(f'{hyper_yaml_path} not found!')
+    
+    with open(hyper_yaml_path, 'r') as f:
+        configs = load_hyperpyyaml(
+            f, 
+            overrides={'qwen_pretrain_path': hf_llm_dir}
+        )
+    
+    # Load only the LLM
+    llm = configs['llm']
+    llm_weights_path = os.path.join(model_dir, 'llm.pt')
+    llm.load_state_dict(torch.load(llm_weights_path, map_location='cpu'), strict=True)
+    llm.eval()
+    
+    logger.info(f"Loaded CosyVoice3 LLM from {model_dir}")
+    
+    return llm, hf_llm_dir, configs
+
+
+def get_speech_token_size(llm) -> int:
+    """Определяет размер speech token vocabulary из модели."""
+    # CosyVoice3LM имеет: speech_token_size + 200 в llm_decoder
+    # speech_embedding имеет: speech_token_size + 200
+    speech_embedding_size = llm.speech_embedding.num_embeddings
+    # Вычитаем 200 специальных токенов (sos, eos, task_id, fill, и т.д.)
+    # Но для безопасности используем полный размер embedding
+    return speech_embedding_size
+
+
+def convert_cosyvoice3_to_hf(
+    model_dir: str,
+    output_dir: str,
+    dtype: str = "bfloat16",
+):
+    """
+    Конвертирует CosyVoice3 LLM в HuggingFace формат с объединёнными embeddings.
+    
+    Архитектура объединения:
+    - embed_tokens[0:original_vocab_size] = оригинальные text embeddings
+    - embed_tokens[original_vocab_size:original_vocab_size+speech_token_size] = speech_embedding
+    - lm_head[original_vocab_size:original_vocab_size+speech_token_size] = llm_decoder
+    
+    Args:
+        model_dir: Путь к CosyVoice3 модели
+        output_dir: Путь для сохранения HF модели
+        dtype: Тип данных для сохранения
+    """
+    logger.info(f"Loading CosyVoice3 model from {model_dir}")
+    
+    # 1. Load the CosyVoice3 components
+    cosyvoice3_llm, hf_llm_dir, configs = load_cosyvoice3_model(model_dir)
+    
+    # Extract the key components
+    qwen_model = cosyvoice3_llm.llm.model  # Qwen2ForCausalLM
+    speech_embedding = cosyvoice3_llm.speech_embedding  # embedding table for speech tokens
+    llm_decoder = cosyvoice3_llm.llm_decoder  # linear head that decodes into speech tokens
+    
+    speech_token_size = get_speech_token_size(cosyvoice3_llm)
+    logger.info(f"Speech token size: {speech_token_size}")
+    
+    # 2. Load the tokenizer and add CosyVoice3 text special tokens + speech tokens
+    tokenizer = AutoTokenizer.from_pretrained(hf_llm_dir, trust_remote_code=True)
+    base_vocab_size = len(tokenizer)
+    logger.info(f"Base tokenizer vocab size: {base_vocab_size}")
+    
+    # IMPORTANT:
+    # - In CosyVoice3, LLM speech special tokens (sos/eos/task_id/fill) are INSIDE speech_embedding,
+    #   i.e. represented as <|s_6561|>, <|s_6562|>, <|s_6563|>, <|s_6564|>.
+    # - But text-level special tokens like [cough]/[laughter] MUST exist in tokenizer
+    #   (mirrors `CosyVoice3Tokenizer` from `cosyvoice/tokenizer/tokenizer.py`).
+    special_tokens = {
+        'eos_token': '<|endoftext|>',
+        'pad_token': '<|endoftext|>',
+        'additional_special_tokens': [
+            '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+            '[breath]', '<strong>', '</strong>', '[noise]',
+            '[laughter]', '[cough]', '[clucking]', '[accent]',
+            '[quick_breath]',
+            "<laughter>", "</laughter>",
+            "[hissing]", "[sigh]", "[vocalized-noise]",
+            "[lipsmack]", "[mn]", "<|endofsystem|>",
+            # Phoneme tokens (kept consistent with CosyVoice3Tokenizer)
+            "[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
+            "[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
+            "[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
+            "[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
+            "[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
+            "[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
+            "[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
+            "[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
+            "[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
+            "[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
+            "[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
+            "[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]",
+            "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
+            "[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
+            "[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
+            "[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
+            "[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
+            "[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
+            "[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
+            "[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
+            "[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
+        ]
+    }
+    tokenizer.add_special_tokens(special_tokens)
+    text_vocab_size = len(tokenizer)
+    logger.info(f"Tokenizer vocab after CosyVoice3 text special tokens: {text_vocab_size}")
+    
+    # Add speech tokens: <|s_0|>, <|s_1|>, ..., <|s_{embedding_size-1}|>
+    # IMPORTANT: This range must match speech_embedding.num_embeddings (includes speech special tokens).
+    actual_speech_tokens = speech_token_size  # Full embedding size (with speech special tokens)
+
+    # replace <|s_6561|> with <|sos|>
+    # replace <|s_6562|> with <|eos1|>
+    # replace <|s_6563|> with <|task_id|>
+    # replace <|s_6564|> with <|fill|>
+    speech_tokens = [f"<|s_{i}|>" for i in range(actual_speech_tokens)]
+    speech_tokens[6561] = "<|sos|>"
+    speech_tokens[6562] = "<|eos1|>"
+    speech_tokens[6563] = "<|task_id|>"
+    speech_tokens[6564] = "<|fill|>"
+    assert "<s_6561>" not in speech_tokens
+    assert "<s_6562>" not in speech_tokens
+    assert "<s_6563>" not in speech_tokens
+    assert "<s_6564>" not in speech_tokens
+    tokenizer.add_tokens(speech_tokens)
+    
+    new_vocab_size = len(tokenizer)
+    logger.info(f"New tokenizer vocab size: {new_vocab_size}")
+    logger.info(f"Added {new_vocab_size - base_vocab_size} tokens total (text special + speech tokens)")
+    
+    # 3. Resize the embeddings of the Qwen model
+    # Pad the vocab to a multiple of 128 for TensorRT efficiency
+    padded_vocab_size = ((new_vocab_size + 127) // 128) * 128
+    qwen_model.resize_token_embeddings(padded_vocab_size)
+    logger.info(f"Resized embeddings to: {padded_vocab_size}")
+    
+    # Speech tokens start after text vocab (base + CosyVoice3 text special tokens)
+    speech_token_offset = text_vocab_size
+
+    # 4. Copy speech_embedding into the extended part of embed_tokens
+    input_embeddings = qwen_model.get_input_embeddings()
+    hidden_size = input_embeddings.weight.shape[1]
+    
+    logger.info(f"Hidden size: {hidden_size}")
+    logger.info(f"speech_embedding shape: {speech_embedding.weight.shape}")
+    logger.info(f"llm_decoder shape: {llm_decoder.weight.shape}")
+    
+    with torch.no_grad():
+        # Copy the speech_embedding weights into embed_tokens
+        # Indices: [speech_token_offset, speech_token_offset + speech_token_size)
+        src_size = min(speech_embedding.weight.shape[0], actual_speech_tokens)
+        input_embeddings.weight[speech_token_offset:speech_token_offset + src_size] = \
+            speech_embedding.weight[:src_size].to(input_embeddings.weight.dtype)
+    
+    logger.info(f"Copied speech_embedding to embed_tokens[{speech_token_offset}:{speech_token_offset + src_size}]")
+    
+    # 5. Create a new lm_head with the extended vocab and copy llm_decoder into it
+    # Original lm_head: hidden_size -> original_vocab_size
+    # New lm_head: hidden_size -> padded_vocab_size
+    # llm_decoder: hidden_size -> speech_token_size
+
+    # Build the new lm_head
+    has_bias = llm_decoder.bias is not None
+    new_lm_head = torch.nn.Linear(
+        in_features=hidden_size,
+        out_features=padded_vocab_size,
+        bias=has_bias
+    )
+    
+    with torch.no_grad():
+        # Initialize the weights:
+        # - Text part: copied from the original lm_head (or zeros)
+        # - Speech part: copied from llm_decoder
+        # - Padding: zeros
+
+        # Start with zero weights and a -inf bias (so text tokens are never generated)
+        new_lm_head.weight.data.zero_()
+        if has_bias:
+            new_lm_head.bias.data.fill_(-float('inf'))
+        
+        # Copy the original lm_head weights for the text tokens (optional)
+        original_lm_head = qwen_model.lm_head
+        if original_lm_head is not None and original_lm_head.weight.shape[0] >= text_vocab_size:
+            new_lm_head.weight[:text_vocab_size] = original_lm_head.weight[:text_vocab_size]
+            if has_bias and original_lm_head.bias is not None:
+                new_lm_head.bias[:text_vocab_size] = original_lm_head.bias[:text_vocab_size]
+        
+        # Copy llm_decoder weights for the speech tokens
+        decoder_size = min(llm_decoder.weight.shape[0], actual_speech_tokens)
+        new_lm_head.weight[speech_token_offset:speech_token_offset + decoder_size] = \
+            llm_decoder.weight[:decoder_size].to(new_lm_head.weight.dtype)
+        
+        if has_bias:
+            new_lm_head.bias[speech_token_offset:speech_token_offset + decoder_size] = \
+                llm_decoder.bias[:decoder_size].to(new_lm_head.bias.dtype)
+        else:
+            # llm_decoder has no bias here, even though one might be wanted for the text tokens
+            pass
+    
+    # Replace lm_head
+    qwen_model.lm_head = new_lm_head
+    
+    logger.info(f"Created new lm_head with shape: {new_lm_head.weight.shape}")
+    logger.info(f"Copied llm_decoder to lm_head[{speech_token_offset}:{speech_token_offset + decoder_size}]")
+    
+    # 6. Update the model configuration
+    qwen_model.config.vocab_size = padded_vocab_size
+    qwen_model.config.tie_word_embeddings = False  # embeddings and lm_head are now different!
+    
+    # Set EOS token for generation (speech EOS lives inside speech_embedding as <|s_{base_speech_token_size+1}|>)
+    base_speech_token_size = getattr(cosyvoice3_llm, "speech_token_size", 6561)
+    eos_speech_idx = base_speech_token_size + 1
+    eos_id = speech_token_offset + eos_speech_idx
+    qwen_model.config.eos_token_id = eos_id
+    
+    # Generation settings
+    qwen_model.generation_config.eos_token_id = eos_id
+    qwen_model.generation_config.pad_token_id = eos_id
+    qwen_model.generation_config.temperature = 0.8
+    qwen_model.generation_config.top_p = 0.95
+    qwen_model.generation_config.top_k = 25
+    qwen_model.generation_config.repetition_penalty = 1.1
+    qwen_model.generation_config.max_new_tokens = 2048
+    
+    # 7. Convert to the requested dtype
+    dtype_map = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "float32": torch.float32,
+    }
+    target_dtype = dtype_map[dtype]
+    qwen_model.to(target_dtype)
+    
+    # 8. Save the model and the tokenizer
+    os.makedirs(output_dir, exist_ok=True)
+    
+    qwen_model.save_pretrained(output_dir)
+    
+    TEMPLATE = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- '<|sos|>' + message['content'] + '<|task_id|>' }}{%- elif message['role'] == 'assistant' %}{{- message['content']}}{%- endif %}{%- endfor %}"
+    tokenizer.chat_template = TEMPLATE
+    tokenizer.save_pretrained(output_dir)
+    
+    # Save metadata for TRT-LLM inference
+    metadata = {
+        "original_vocab_size": base_vocab_size,
+        "text_vocab_size": text_vocab_size,
+        "base_speech_token_size": base_speech_token_size,
+        "embedding_size": actual_speech_tokens,
+        "padded_vocab_size": padded_vocab_size,
+        "eos_token_id": eos_id,
+        "speech_token_offset": speech_token_offset,
+        "dtype": dtype,
+    }
+    
+    import json
+    with open(os.path.join(output_dir, "cosyvoice3_metadata.json"), "w") as f:
+        json.dump(metadata, f, indent=2)
+    
+    logger.info(f"Saved HuggingFace model to {output_dir}")
+    logger.info(f"Metadata: {metadata}")
+    
+    return output_dir, metadata
+
+
+def main():
+    args = parse_args()
+    
+    output_dir = args.output_dir
+    if output_dir is None:
+        output_dir = os.path.join(args.model_dir, "hf_merged")
+    
+    convert_cosyvoice3_to_hf(
+        model_dir=args.model_dir,
+        output_dir=output_dir,
+        dtype=args.dtype,
+    )
+    
+    print("\n" + "=" * 70)
+    print("✅ Conversion complete!")
+    print("=" * 70)
+    print(f"\nHuggingFace model saved to: {output_dir}")
+    print("\nNext steps:")
+    print("1. Convert to TRT-LLM weights:")
+    print(f"   python -c \"from tensorrt_llm.models import QWenForCausalLM; ...")
+    print("\n2. Build TRT-LLM engines:")
+    print(f"   trtllm-build --checkpoint_dir <trt_weights_dir> --output_dir <trt_engines_dir> ...")
+    print("=" * 70)
+
+
+if __name__ == "__main__":
+    main()
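For downstream use, a minimal sketch (illustrative path and helper name) of mapping merged-vocab ids generated by the converted model back to CosyVoice3 speech token ids via the metadata file written above; with the saved chat template, a user turn renders as '<|sos|>' + text + '<|task_id|>' and the assistant continues with speech tokens.

    import json

    with open("hf_cosyvoice3_llm/cosyvoice3_metadata.json") as f:   # dir name taken from run_cosyvoice3.sh
        meta = json.load(f)

    def to_speech_ids(generated_ids):
        offset, eos_id = meta["speech_token_offset"], meta["eos_token_id"]
        # drop EOS and anything outside the speech range, then shift back to 0-based speech ids
        return [t - offset for t in generated_ids if t >= offset and t != eos_id]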

+ 414 - 0
runtime/triton_trtllm/token2wav_cosyvoice3.py

@@ -0,0 +1,414 @@
+""" Example Usage
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 token2wav_cosyvoice3.py --enable-trt || exit 1
+"""
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+import onnxruntime
+import s3tokenizer
+import os
+import logging
+import argparse
+import queue
+import time
+import numpy as np
+from functools import partial
+from hyperpyyaml import load_hyperpyyaml
+from matcha.utils.audio import mel_spectrogram as matcha_mel_spectrogram
+from torch.utils.data import DataLoader
+from datasets import load_dataset
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# CosyVoice3 mel params from cosyvoice3.yaml (fmax=None, NOT 8000)
+mel_spectrogram = partial(matcha_mel_spectrogram,
+    n_fft=1920, num_mels=80, sampling_rate=24000,
+    hop_size=480, win_size=1920, fmin=0, fmax=None, center=False)
+
+
+def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16, autocast_mode=False):
+    import tensorrt as trt
+    logging.info("Converting onnx to trt...")
+    if autocast_mode:
+        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
+    else:
+        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    trt_logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(trt_logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, trt_logger)
+    config = builder.create_builder_config()
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
+    if not autocast_mode:
+        if fp16:
+            config.set_flag(trt.BuilderFlag.FP16)
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(trt_kwargs['input_names'])):
+        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
+    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
+    logging.info("Succesfully convert onnx to trt...")
+
+
+class TrtContextWrapper:
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        self.trt_engine = trt_engine
+        self.device = device
+        for _ in range(trt_concurrent):
+            trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
+            assert trt_context is not None, 'failed to create trt context, possibly not enough CUDA memory; try reducing trt_concurrent (currently {})'.format(trt_concurrent)
+            self.trt_context_pool.put([trt_context, trt_stream])
+        assert self.trt_context_pool.empty() is False, 'no available estimator context'
+
+    def acquire_estimator(self):
+        return self.trt_context_pool.get(), self.trt_engine
+
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])
+
+
+class CosyVoice3_Token2Wav(torch.nn.Module):
+    def __init__(self, model_dir, enable_trt=False, device_id=0, autocast_mode=True, streaming=False):
+        super().__init__()
+        self.device_id = device_id
+        self.device = f"cuda:{device_id}"
+        self.autocast_mode = autocast_mode
+        self.streaming = streaming
+
+        # Load flow and hift from cosyvoice3.yaml
+        with open(f"{model_dir}/cosyvoice3.yaml", "r") as f:
+            configs = load_hyperpyyaml(f, overrides={
+                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
+            })
+        self.flow = configs['flow']
+        self.flow.load_state_dict(
+            torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True),
+            strict=True
+        )
+        self.flow.to(self.device).eval()
+
+        self.hift = configs['hift']
+        hift_state_dict = {
+            k.replace('generator.', ''): v
+            for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()
+        }
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+        # Speaker embedding model (campplus)
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.spk_model = onnxruntime.InferenceSession(
+            f"{model_dir}/campplus.onnx", sess_options=option,
+            providers=["CPUExecutionProvider"]
+        )
+
+        # Audio tokenizer v3
+        self.audio_tokenizer = s3tokenizer.load_model(
+            f"{model_dir}/speech_tokenizer_v3.onnx"
+        ).to(self.device).eval()
+
+        self.fp16 = enable_trt
+        if enable_trt:
+            self.flow.half()
+            self.load_trt(model_dir)
+            self.load_spk_trt(model_dir)
+
+    def load_trt(self, model_dir, trt_concurrent=1):
+        streaming_prefix = 'streaming.' if self.streaming else ''
+        if self.autocast_mode:
+            onnx_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}autocast_fp16.onnx'
+            trt_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}autocast_fp16.{self.device_id}.plan'
+        else:
+            onnx_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}fp32.onnx'
+            trt_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}fp32.{self.device_id}.plan'
+
+        if not os.path.exists(trt_path) or os.path.getsize(trt_path) == 0:
+            trt_kwargs = self.get_trt_kwargs()
+            convert_onnx_to_trt(trt_path, trt_kwargs, onnx_path,
+                               fp16=True, autocast_mode=self.autocast_mode)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(trt_path, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(trt_path)
+        self.flow.decoder.estimator = TrtContextWrapper(
+            estimator_engine, trt_concurrent=trt_concurrent, device=self.device
+        )
+
+    def get_trt_kwargs(self):
+        # CosyVoice3 DiT estimator has 6 inputs: x, mask, mu, t, spks, cond
+        # Only inputs with dynamic dims need optimization profiles.
+        # t=[2(fixed)] and spks=[2(fixed),80(fixed)] are fully fixed, TRT infers from ONNX.
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
+        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
+        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
+        input_names = ["x", "mask", "mu", "cond"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape,
+                'max_shape': max_shape, 'input_names': input_names}
+
+    def load_spk_trt(self, model_dir, trt_concurrent=1, fp16=False):
+        spk_trt_path = f'{model_dir}/campplus.{self.device_id}.fp32.plan'
+        spk_onnx_path = f'{model_dir}/campplus.onnx'
+        if not os.path.exists(spk_trt_path) or os.path.getsize(spk_trt_path) == 0:
+            trt_kwargs = self.get_spk_trt_kwargs()
+            convert_onnx_to_trt(spk_trt_path, trt_kwargs, spk_onnx_path, fp16)
+        import tensorrt as trt
+        with open(spk_trt_path, 'rb') as f:
+            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert spk_engine is not None, 'failed to load trt {}'.format(spk_trt_path)
+        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_spk_trt_kwargs(self):
+        min_shape = [(1, 4, 80)]
+        opt_shape = [(1, 500, 80)]
+        max_shape = [(1, 3000, 80)]
+        input_names = ["input"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape,
+                'max_shape': max_shape, 'input_names': input_names}
+
+    def forward_spk_embedding(self, spk_feat):
+        if isinstance(self.spk_model, onnxruntime.InferenceSession):
+            return self.spk_model.run(
+                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+            )[0].flatten().tolist()
+        else:
+            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
+            with torch.cuda.device(self.device_id):
+                torch.cuda.current_stream().synchronize()
+                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
+                batch_size = spk_feat.size(0)
+
+                with stream:
+                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
+                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
+
+                    data_ptrs = [spk_feat.contiguous().data_ptr(),
+                                 output_tensor.contiguous().data_ptr()]
+                    for i, ptr in enumerate(data_ptrs):
+                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), ptr)
+                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                    torch.cuda.current_stream().synchronize()
+                self.spk_model.release_estimator(spk_model, stream)
+
+            return output_tensor.cpu().numpy().flatten().tolist()
+
+    def prompt_audio_tokenization(self, prompt_audios_list):
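+        """Tokenize 16 kHz prompt audio into discrete speech tokens.
+
+        Mel features are extracted with s3tokenizer, padded into a batch and
+        quantized; returns one Python list of token ids per prompt.
+        """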
+        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            log_mel = s3tokenizer.log_mel_spectrogram(audio)
+            prompt_speech_mels_list.append(log_mel)
+        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
+        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
+            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
+        )
+        for i in range(len(prompt_speech_tokens)):
+            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
+            prompt_speech_tokens_list.append(speech_tokens_i)
+        return prompt_speech_tokens_list
+
+    def get_spk_emb(self, prompt_audios_list):
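+        """Compute one speaker embedding per prompt audio.
+
+        80-dim Kaldi fbank features are mean-normalized per utterance and fed
+        through the campplus model (ONNX or TensorRT); returns a [B, 192] tensor.
+        """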
+        spk_emb_for_flow = []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
+            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
+            spk_emb = self.forward_spk_embedding(spk_feat)
+            spk_emb_for_flow.append(spk_emb)
+        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
+        return spk_emb_for_flow
+
+    def get_prompt_mels(self, prompt_audios_list, prompt_audios_sample_rate):
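+        """Extract prompt mel features for the flow decoder.
+
+        Prompts are resampled to 24 kHz when needed; returns padded [B, T', 80]
+        mels together with their unpadded lengths.
+        """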
+        prompt_mels_for_flow = []
+        prompt_mels_lens_for_flow = []
+        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
+            assert len(audio.shape) == 1
+            audio = audio.unsqueeze(0)
+            if sample_rate != 24000:
+                audio = torchaudio.transforms.Resample(
+                    orig_freq=sample_rate, new_freq=24000)(audio)
+            # CosyVoice3: fmax=None (Nyquist), matching cosyvoice3.yaml
+            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, 80]
+            prompt_mels_for_flow.append(mel)
+            prompt_mels_lens_for_flow.append(mel.shape[0])
+        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
+            prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', 80]
+        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
+        return prompt_mels_for_flow, prompt_mels_lens_for_flow
+
+    def forward_flow(self, prompt_speech_tokens_list, generated_speech_tokens_list,
+                     prompt_mels_for_flow, prompt_mels_lens_for_flow,
+                     spk_emb_for_flow):
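+        """Run the flow-matching decoder sample by sample and return the
+        generated mels (prompt portion already removed by flow.inference)."""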
+        batch_size = len(generated_speech_tokens_list)
+        generated_mels_list = []
+
+        # CausalMaskedDiffWithDiT.inference asserts batch_size==1, so iterate per-sample
+        for i in range(batch_size):
+            token = torch.tensor([generated_speech_tokens_list[i]]).to(self.device)
+            token_len = torch.tensor([len(generated_speech_tokens_list[i])]).to(self.device)
+            prompt_token = torch.tensor([prompt_speech_tokens_list[i]]).to(self.device)
+            prompt_token_len = torch.tensor([len(prompt_speech_tokens_list[i])]).to(self.device)
+            prompt_feat = prompt_mels_for_flow[i:i+1, :prompt_mels_lens_for_flow[i]].to(self.device)
+            prompt_feat_len = prompt_mels_lens_for_flow[i:i+1].to(self.device)
+            embedding = spk_emb_for_flow[i:i+1].to(self.device)
+
+            # CausalMaskedDiffWithDiT.inference already strips the prompt portion from the returned mel
+            with torch.cuda.amp.autocast(self.fp16):
+                mel, _ = self.flow.inference(
+                    token=token,
+                    token_len=token_len,
+                    prompt_token=prompt_token,
+                    prompt_token_len=prompt_token_len,
+                    prompt_feat=prompt_feat,
+                    prompt_feat_len=prompt_feat_len,
+                    embedding=embedding,
+                    streaming=False,
+                    finalize=True
+                )
+            generated_mels_list.append(mel)
+
+        return generated_mels_list
+
+    def forward_hift(self, generated_mels_list):
+        generated_wavs = []
+        for mel in generated_mels_list:
+            # CausalHiFTGenerator.inference with finalize=True
+            wav, _ = self.hift.inference(speech_feat=mel, finalize=True)
+            generated_wavs.append(wav)
+        return generated_wavs
+
+    def forward_stream(self, generated_speech_tokens, prompt_speech_tokens,
+                        prompt_feat, embedding,
+                        token_hop_len=25, stream_scale_factor=2, token_max_hop_len=100):
+        """Streaming token2wav for a single sample: process tokens in chunks."""
+        prompt_token = torch.tensor([prompt_speech_tokens]).to(self.device)
+        prompt_token_len = torch.tensor([len(prompt_speech_tokens)]).to(self.device)
+        prompt_feat = prompt_feat.to(self.device)
+        prompt_feat_len = torch.tensor([prompt_feat.shape[1]]).to(self.device)
+        embedding = embedding.to(self.device)
+
+        pre_lookahead_len = self.flow.pre_lookahead_len
+        token_mel_ratio = self.flow.token_mel_ratio
+
+        # Align first chunk with hop_len boundary
+        prompt_token_pad = int(
+            np.ceil(prompt_token.shape[1] / token_hop_len) * token_hop_len
+            - prompt_token.shape[1]
+        )
+
+        total_tokens = len(generated_speech_tokens)
+        token_offset = 0
+        current_hop = token_hop_len
+        hift_cache_mel = None
+        speech_offset = 0
+        audio_chunks = []
+
+        while token_offset < total_tokens:
+            this_hop = (current_hop + prompt_token_pad) if token_offset == 0 else current_hop
+            remaining = total_tokens - token_offset
+
+            if remaining >= this_hop + pre_lookahead_len:
+                end_idx = token_offset + this_hop + pre_lookahead_len
+                this_token = torch.tensor([generated_speech_tokens[:end_idx]]).to(self.device)
+                finalize = False
+            else:
+                this_token = torch.tensor([generated_speech_tokens]).to(self.device)
+                finalize = True
+
+            with torch.cuda.amp.autocast(self.fp16):
+                mel, _ = self.flow.inference(
+                    token=this_token,
+                    token_len=torch.tensor([this_token.shape[1]]).to(self.device),
+                    prompt_token=prompt_token,
+                    prompt_token_len=prompt_token_len,
+                    prompt_feat=prompt_feat,
+                    prompt_feat_len=prompt_feat_len,
+                    embedding=embedding,
+                    streaming=True,
+                    finalize=finalize,
+                )
+
+            mel = mel[:, :, token_offset * token_mel_ratio:]
+
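+            # Accumulate the full mel history, re-run the vocoder over it each
+            # chunk, and emit only the samples beyond speech_offset that have
+            # not been returned yet.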
+            if hift_cache_mel is not None:
+                mel = torch.concat([hift_cache_mel, mel], dim=2)
+            hift_cache_mel = mel
+
+            tts_speech, _ = self.hift.inference(speech_feat=mel, finalize=finalize)
+            tts_speech = tts_speech[:, speech_offset:]
+            speech_offset += tts_speech.shape[1]
+
+            logger.info(f"[stream] token_offset={token_offset}, this_hop={this_hop}, "
+                        f"mel_shape={mel.shape}, speech_len={tts_speech.shape[1]}, finalize={finalize}")
+
+            audio_chunks.append(tts_speech)
+
+            token_offset += this_hop
+            if not finalize:
+                current_hop = min(token_max_hop_len, current_hop * stream_scale_factor)
+            else:
+                break
+
+        return torch.cat(audio_chunks, dim=1)
+
+    @torch.inference_mode()
+    def forward(self, generated_speech_tokens_list, prompt_audios_list,
+                prompt_audios_sample_rate, streaming=False):
+        assert all(sr == 16000 for sr in prompt_audios_sample_rate)
+
+        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
+        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(
+            prompt_audios_list, prompt_audios_sample_rate)
+        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
+
+        # Align prompt_speech_feat and prompt_speech_token to exact 2:1 ratio
+        # (matches frontend.frontend_zero_shot logic)
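+        # e.g. 375 prompt mel frames with 190 prompt tokens gives
+        # token_len = min(375 // 2, 190) = 187, so the mel length is clipped to 374.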
+        for i in range(len(prompt_speech_tokens_list)):
+            token_len = min(int(prompt_mels_lens_for_flow[i].item() / 2),
+                            len(prompt_speech_tokens_list[i]))
+            prompt_speech_tokens_list[i] = prompt_speech_tokens_list[i][:token_len]
+            prompt_mels_lens_for_flow[i] = 2 * token_len
+
+        if streaming:
+            generated_wavs = []
+            for i in range(len(generated_speech_tokens_list)):
+                prompt_feat = prompt_mels_for_flow[i:i+1, :prompt_mels_lens_for_flow[i]]
+                embedding = spk_emb_for_flow[i:i+1]
+                wav = self.forward_stream(
+                    generated_speech_tokens_list[i],
+                    prompt_speech_tokens_list[i],
+                    prompt_feat, embedding,
+                )
+                generated_wavs.append(wav)
+            return generated_wavs
+
+        generated_mels_list = self.forward_flow(
+            prompt_speech_tokens_list, generated_speech_tokens_list,
+            prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+
+        generated_wavs = self.forward_hift(generated_mels_list)
+        return generated_wavs
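+
+    # Minimal usage sketch (hedged: construction details, batching and error
+    # handling are elided; the 24 kHz output rate follows the mel/hift config
+    # used above):
+    #
+    #   t2w = ...                                          # an instance of this token2wav wrapper
+    #   wav16k, _ = torchaudio.load("prompt.wav")          # must already be 16 kHz mono
+    #   wavs = t2w.forward(
+    #       generated_speech_tokens_list=[llm_token_ids],  # list[int] speech tokens from the LLM
+    #       prompt_audios_list=[wav16k.squeeze(0)],
+    #       prompt_audios_sample_rate=[16000],
+    #       streaming=False,
+    #   )
+    #   torchaudio.save("out.wav", wavs[0].cpu(), 24000)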