Преглед изворни кода

Merge pull request #1489 from yuekaizhang/triton

[runtime] Support Cosyvoice2 Nvidia TensorRT-LLM Inference Solution
Xiang Lyu пре 3 месеци
родитељ
комит
1850e2a56e

+ 6 - 0
runtime/triton_trtllm/Dockerfile.server

@@ -0,0 +1,6 @@
+FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3
+RUN apt-get update && apt-get install -y cmake
+RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
+COPY ./requirements.txt /workspace/requirements.txt
+RUN pip install -r /workspace/requirements.txt
+WORKDIR /workspace

+ 89 - 0
runtime/triton_trtllm/README.md

@@ -0,0 +1,89 @@
+## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server
+
+### Quick Start
+Launch the service directly with Docker Compose:
+```sh
+docker compose up
+```
+
+### Build the Docker Image
+Build the image from scratch:
+```sh
+docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
+```
+
+### Run a Docker Container
+```sh
+your_mount_dir=/mnt:/mnt
+docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
+```
+
+### Understanding `run.sh`
+The `run.sh` script orchestrates the entire workflow through numbered stages.
+
+Run a subset of stages with:
+```sh
+bash run.sh <start_stage> <stop_stage> [service_type]
+```
+- `<start_stage>` – stage to start from (0-5).
+- `<stop_stage>`  – stage to stop after (0-5).
+
+Stages:
+- **Stage 0** – Download the cosyvoice-2 0.5B model from HuggingFace.
+- **Stage 1** – Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
+- **Stage 2** – Create the Triton model repository and configure the model files (adjusts depending on whether `Decoupled=True/False` will be used later).
+- **Stage 3** – Launch the Triton Inference Server.
+- **Stage 4** – Run the single-utterance HTTP client.
+- **Stage 5** – Run the gRPC benchmark client.
+
+### Export Models to TensorRT-LLM and Launch the Server
+Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
+```sh
+# Runs stages 0, 1, 2, and 3
+bash run.sh 0 3
+```
+*Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.*
+
+### Single-Utterance HTTP Client
+Send a single HTTP inference request:
+```sh
+bash run.sh 4 4
+```
+
+### Benchmark with a Dataset
+Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument.
+```sh
+bash run.sh 5 5
+
+# You can also customise parameters such as num_task and dataset split directly:
+# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
+```
+> [!TIP]
+> Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Triton’s decoupled mode so that responses are returned as soon as they are ready.
+
+### Benchmark Results
+Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio):
+
+| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
+|------|------|-------------|------------------|------------------|-----|
+| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
+| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
+| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
+| Decoupled=True  | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 |
+| Decoupled=True  | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 |
+| Decoupled=True  | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 |
+
+### OpenAI-Compatible Server
+To launch an OpenAI-compatible service, run:
+```sh
+git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
+pip install -r requirements.txt
+# After the Triton service is up, start the FastAPI bridge:
+python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
+# Test with curl
+bash test/test_cosyvoice.sh
+```
+
+### Acknowledgements
+This section originates from the NVIDIA CISI project. We also provide other multimodal resources—see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details.
+

+ 834 - 0
runtime/triton_trtllm/client_grpc.py

@@ -0,0 +1,834 @@
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#                2023  Nvidia              (authors: Yuekai Zhang)
+#                2023  Recurrent.ai        (authors: Songtao Shi)
+# See LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script supports to load dataset from huggingface and sends it to the server
+for decoding, in parallel.
+
+Usage:
+num_task=2
+
+# For offline F5-TTS
+python3 client_grpc.py \
+    --server-addr localhost \
+    --model-name f5_tts \
+    --num-tasks $num_task \
+    --huggingface-dataset yuekai/seed_tts \
+    --split-name test_zh \
+    --log-dir ./log_concurrent_tasks_${num_task}
+
+# For offline Spark-TTS-0.5B
+python3 client_grpc.py \
+    --server-addr localhost \
+    --model-name spark_tts \
+    --num-tasks $num_task \
+    --huggingface-dataset yuekai/seed_tts \
+    --split-name wenetspeech4tts \
+    --log-dir ./log_concurrent_tasks_${num_task}
+"""
+
+import argparse
+import asyncio
+import json
+import queue  # Added
+import uuid  # Added
+import functools  # Added
+
+import os
+import time
+import types
+from pathlib import Path
+
+import numpy as np
+import soundfile as sf
+import tritonclient
+import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
+import tritonclient.grpc as grpcclient_sync  # Added sync client import
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException
+
+
+# --- Added UserData and callback ---
+class UserData:
+    def __init__(self):
+        self._completed_requests = queue.Queue()
+        self._first_chunk_time = None
+        self._start_time = None
+
+    def record_start_time(self):
+        self._start_time = time.time()
+
+    def get_first_chunk_latency(self):
+        if self._first_chunk_time and self._start_time:
+            return self._first_chunk_time - self._start_time
+        return None
+
+
+def callback(user_data, result, error):
+    if user_data._first_chunk_time is None and not error:
+        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
+    if error:
+        user_data._completed_requests.put(error)
+    else:
+        user_data._completed_requests.put(result)
+# --- End Added UserData and callback ---
+
+
+def write_triton_stats(stats, summary_file):
+    with open(summary_file, "w") as summary_f:
+        model_stats = stats["model_stats"]
+        # write a note, the log is from triton_client.get_inference_statistics(), to better human readability
+        summary_f.write(
+            "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
+        )
+        summary_f.write("To learn more about the log, please refer to: \n")
+        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
+        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
+        summary_f.write(
+            "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
+        )
+        summary_f.write(
+            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
+        )
+        summary_f.write(
+            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
+        )
+        summary_f.write(
+            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
+        )
+        for model_state in model_stats:
+            if "last_inference" not in model_state:
+                continue
+            summary_f.write(f"model name is {model_state['name']} \n")
+            model_inference_stats = model_state["inference_stats"]
+            total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
+            total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
+            total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
+            total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
+            summary_f.write(
+                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
+            )
+            model_batch_stats = model_state["batch_stats"]
+            for batch in model_batch_stats:
+                batch_size = int(batch["batch_size"])
+                compute_input = batch["compute_input"]
+                compute_output = batch["compute_output"]
+                compute_infer = batch["compute_infer"]
+                batch_count = int(compute_infer["count"])
+                assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
+                compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
+                compute_input_time_ms = int(compute_input["ns"]) / 1e6
+                compute_output_time_ms = int(compute_output["ns"]) / 1e6
+                summary_f.write(
+                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
+                )
+                summary_f.write(
+                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
+                )
+                summary_f.write(
+                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
+                )
+
+
+def get_args():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+    parser.add_argument(
+        "--server-addr",
+        type=str,
+        default="localhost",
+        help="Address of the server",
+    )
+
+    parser.add_argument(
+        "--server-port",
+        type=int,
+        default=8001,
+        help="Grpc port of the triton server, default is 8001",
+    )
+
+    parser.add_argument(
+        "--reference-audio",
+        type=str,
+        default=None,
+        help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
+    )
+
+    parser.add_argument(
+        "--reference-text",
+        type=str,
+        default="",
+        help="",
+    )
+
+    parser.add_argument(
+        "--target-text",
+        type=str,
+        default="",
+        help="",
+    )
+
+    parser.add_argument(
+        "--huggingface-dataset",
+        type=str,
+        default="yuekai/seed_tts",
+        help="dataset name in huggingface dataset hub",
+    )
+
+    parser.add_argument(
+        "--split-name",
+        type=str,
+        default="wenetspeech4tts",
+        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
+        help="dataset split name, default is 'test'",
+    )
+
+    parser.add_argument(
+        "--manifest-path",
+        type=str,
+        default=None,
+        help="Path to the manifest dir which includes wav.scp trans.txt files.",
+    )
+
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="f5_tts",
+        choices=[
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
+    )
+
+    parser.add_argument(
+        "--num-tasks",
+        type=int,
+        default=1,
+        help="Number of concurrent tasks for sending",
+    )
+
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=5,
+        help="Controls how frequently we print the log.",
+    )
+
+    parser.add_argument(
+        "--compute-wer",
+        action="store_true",
+        default=False,
+        help="""True to compute WER.
+        """,
+    )
+
+    parser.add_argument(
+        "--log-dir",
+        type=str,
+        required=False,
+        default="./tmp",
+        help="log directory",
+    )
+
+    # --- Added arguments ---
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="offline",
+        choices=["offline", "streaming"],
+        help="Select offline or streaming benchmark mode."
+    )
+    parser.add_argument(
+        "--chunk-overlap-duration",
+        type=float,
+        default=0.1,
+        help="Chunk overlap duration for streaming reconstruction (in seconds)."
+    )
+    # --- End Added arguments ---
+
+    return parser.parse_args()
+
+
+def load_audio(wav_path, target_sample_rate=16000):
+    assert target_sample_rate == 16000, "hard coding in server"
+    if isinstance(wav_path, dict):
+        waveform = wav_path["array"]
+        sample_rate = wav_path["sampling_rate"]
+    else:
+        waveform, sample_rate = sf.read(wav_path)
+    if sample_rate != target_sample_rate:
+        from scipy.signal import resample
+
+        num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
+        waveform = resample(waveform, num_samples)
+    return waveform, target_sample_rate
+
+
+def prepare_request_input_output(
+    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
+    waveform,
+    reference_text,
+    target_text,
+    sample_rate=16000,
+    padding_duration: int = None  # Optional padding for offline mode
+):
+    """Prepares inputs for Triton inference (offline or streaming)."""
+    assert len(waveform.shape) == 1, "waveform should be 1D"
+    lengths = np.array([[len(waveform)]], dtype=np.int32)
+
+    # Apply padding only if padding_duration is provided (for offline)
+    if padding_duration:
+        duration = len(waveform) / sample_rate
+        # Estimate target duration based on text length ratio (crude estimation)
+        # Avoid division by zero if reference_text is empty
+        if reference_text:
+            estimated_target_duration = duration / len(reference_text) * len(target_text)
+        else:
+            estimated_target_duration = duration  # Assume target duration similar to reference if no text
+
+        # Calculate required samples based on estimated total duration
+        required_total_samples = padding_duration * sample_rate * (
+            (int(estimated_target_duration + duration) // padding_duration) + 1
+        )
+        samples = np.zeros((1, required_total_samples), dtype=np.float32)
+        samples[0, : len(waveform)] = waveform
+    else:
+        # No padding for streaming or if padding_duration is None
+        samples = waveform.reshape(1, -1).astype(np.float32)
+
+    # Common input creation logic
+    inputs = [
+        protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
+        protocol_client.InferInput(
+            "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
+        ),
+        protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
+        protocol_client.InferInput("target_text", [1, 1], "BYTES"),
+    ]
+    inputs[0].set_data_from_numpy(samples)
+    inputs[1].set_data_from_numpy(lengths)
+
+    input_data_numpy = np.array([reference_text], dtype=object)
+    input_data_numpy = input_data_numpy.reshape((1, 1))
+    inputs[2].set_data_from_numpy(input_data_numpy)
+
+    input_data_numpy = np.array([target_text], dtype=object)
+    input_data_numpy = input_data_numpy.reshape((1, 1))
+    inputs[3].set_data_from_numpy(input_data_numpy)
+
+    outputs = [protocol_client.InferRequestedOutput("waveform")]
+
+    return inputs, outputs
+
+
+def run_sync_streaming_inference(
+    sync_triton_client: tritonclient.grpc.InferenceServerClient,
+    model_name: str,
+    inputs: list,
+    outputs: list,
+    request_id: str,
+    user_data: UserData,
+    chunk_overlap_duration: float,
+    save_sample_rate: int,
+    audio_save_path: str,
+):
+    """Helper function to run the blocking sync streaming call."""
+    start_time_total = time.time()
+    user_data.record_start_time()  # Record start time for first chunk latency calculation
+
+    # Establish stream
+    sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
+
+    # Send request
+    sync_triton_client.async_stream_infer(
+        model_name,
+        inputs,
+        request_id=request_id,
+        outputs=outputs,
+        enable_empty_final_response=True,
+    )
+
+    # Process results
+    audios = []
+    while True:
+        try:
+            result = user_data._completed_requests.get()  # Add timeout
+            if isinstance(result, InferenceServerException):
+                print(f"Received InferenceServerException: {result}")
+                sync_triton_client.stop_stream()
+                return None, None, None  # Indicate error
+            # Get response metadata
+            response = result.get_response()
+            final = response.parameters["triton_final_response"].bool_param
+            if final is True:
+                break
+
+            audio_chunk = result.as_numpy("waveform").reshape(-1)
+            if audio_chunk.size > 0:  # Only append non-empty chunks
+                audios.append(audio_chunk)
+            else:
+                print("Warning: received empty audio chunk.")
+
+        except queue.Empty:
+            print(f"Timeout waiting for response for request id {request_id}")
+            sync_triton_client.stop_stream()
+            return None, None, None  # Indicate error
+
+    sync_triton_client.stop_stream()
+    end_time_total = time.time()
+    total_request_latency = end_time_total - start_time_total
+    first_chunk_latency = user_data.get_first_chunk_latency()
+
+    # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
+    actual_duration = 0
+    if audios:
+        cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
+        fade_out = np.linspace(1, 0, cross_fade_samples)
+        fade_in = np.linspace(0, 1, cross_fade_samples)
+        reconstructed_audio = None
+
+        # Simplified reconstruction based on client_grpc_streaming.py
+        if not audios:
+            print("Warning: No audio chunks received.")
+            reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
+        elif len(audios) == 1:
+            reconstructed_audio = audios[0]
+        else:
+            reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
+            for i in range(1, len(audios)):
+                # Cross-fade section
+                cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
+                                       audios[i - 1][-cross_fade_samples:] * fade_out)
+                # Middle section of the current chunk
+                middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
+                # Concatenate
+                reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
+            # Add the last part of the final chunk
+            reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
+
+        if reconstructed_audio is not None and reconstructed_audio.size > 0:
+            actual_duration = len(reconstructed_audio) / save_sample_rate
+            # Save reconstructed audio
+            os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
+            sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
+        else:
+            print("Warning: No audio chunks received or reconstructed.")
+            actual_duration = 0  # Set duration to 0 if no audio
+
+    else:
+        print("Warning: No audio chunks received.")
+        actual_duration = 0
+
+    return total_request_latency, first_chunk_latency, actual_duration
+
+
+async def send_streaming(
+    manifest_item_list: list,
+    name: str,
+    server_url: str,  # Changed from sync_triton_client
+    protocol_client: types.ModuleType,
+    log_interval: int,
+    model_name: str,
+    audio_save_dir: str = "./",
+    save_sample_rate: int = 16000,
+    chunk_overlap_duration: float = 0.1,
+    padding_duration: int = None,
+):
+    total_duration = 0.0
+    latency_data = []
+    task_id = int(name[5:])
+    sync_triton_client = None  # Initialize client variable
+
+    try:  # Wrap in try...finally to ensure client closing
+        print(f"{name}: Initializing sync client for streaming...")
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here
+
+        print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
+        for i, item in enumerate(manifest_item_list):
+            if i % log_interval == 0:
+                print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
+
+            try:
+                waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
+                reference_text, target_text = item["reference_text"], item["target_text"]
+
+                inputs, outputs = prepare_request_input_output(
+                    protocol_client,
+                    waveform,
+                    reference_text,
+                    target_text,
+                    sample_rate,
+                    padding_duration=padding_duration
+                )
+                request_id = str(uuid.uuid4())
+                user_data = UserData()
+
+                audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
+
+                total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
+                    run_sync_streaming_inference,
+                    sync_triton_client,
+                    model_name,
+                    inputs,
+                    outputs,
+                    request_id,
+                    user_data,
+                    chunk_overlap_duration,
+                    save_sample_rate,
+                    audio_save_path
+                )
+
+                if total_request_latency is not None:
+                    print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
+                    latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
+                    total_duration += actual_duration
+                else:
+                    print(f"{name}: Item {i} failed.")
+
+            except FileNotFoundError:
+                print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
+            except Exception as e:
+                print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
+                import traceback
+                traceback.print_exc()
+
+    finally:  # Ensure client is closed
+        if sync_triton_client:
+            try:
+                print(f"{name}: Closing sync client...")
+                sync_triton_client.close()
+            except Exception as e:
+                print(f"{name}: Error closing sync client: {e}")
+
+    print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
+    return total_duration, latency_data
+
+
+async def send(
+    manifest_item_list: list,
+    name: str,
+    triton_client: tritonclient.grpc.aio.InferenceServerClient,
+    protocol_client: types.ModuleType,
+    log_interval: int,
+    model_name: str,
+    padding_duration: int = None,
+    audio_save_dir: str = "./",
+    save_sample_rate: int = 16000,
+):
+    total_duration = 0.0
+    latency_data = []
+    task_id = int(name[5:])
+
+    print(f"manifest_item_list: {manifest_item_list}")
+    for i, item in enumerate(manifest_item_list):
+        if i % log_interval == 0:
+            print(f"{name}: {i}/{len(manifest_item_list)}")
+        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
+        reference_text, target_text = item["reference_text"], item["target_text"]
+
+        inputs, outputs = prepare_request_input_output(
+            protocol_client,
+            waveform,
+            reference_text,
+            target_text,
+            sample_rate,
+            padding_duration=padding_duration
+        )
+        sequence_id = 100000000 + i + task_id * 10
+        start = time.time()
+        response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
+
+        audio = response.as_numpy("waveform").reshape(-1)
+        actual_duration = len(audio) / save_sample_rate
+
+        end = time.time() - start
+
+        audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
+        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
+
+        latency_data.append((end, actual_duration))
+        total_duration += actual_duration
+
+    return total_duration, latency_data
+
+
+def load_manifests(manifest_path):
+    with open(manifest_path, "r") as f:
+        manifest_list = []
+        for line in f:
+            assert len(line.strip().split("|")) == 4
+            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
+            utt = Path(utt).stem
+            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
+            if not os.path.isabs(prompt_wav):
+                prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
+            manifest_list.append(
+                {
+                    "audio_filepath": prompt_wav,
+                    "reference_text": prompt_text,
+                    "target_text": gt_text,
+                    "target_audio_path": utt,
+                }
+            )
+    return manifest_list
+
+
+def split_data(data, k):
+    n = len(data)
+    if n < k:
+        print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
+        k = n
+
+    quotient = n // k
+    remainder = n % k
+
+    result = []
+    start = 0
+    for i in range(k):
+        if i < remainder:
+            end = start + quotient + 1
+        else:
+            end = start + quotient
+
+        result.append(data[start:end])
+        start = end
+
+    return result
+
+
+async def main():
+    args = get_args()
+    url = f"{args.server_addr}:{args.server_port}"
+
+    # --- Client Initialization based on mode ---
+    triton_client = None
+    protocol_client = None
+    if args.mode == "offline":
+        print("Initializing gRPC client for offline mode...")
+        # Use the async client for offline tasks
+        triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
+        protocol_client = grpcclient_aio
+    elif args.mode == "streaming":
+        print("Initializing gRPC client for streaming mode...")
+        # Use the sync client for streaming tasks, handled via asyncio.to_thread
+        # We will create one sync client instance PER TASK inside send_streaming.
+        # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
+        protocol_client = grpcclient_sync  # protocol client for input prep
+    else:
+        raise ValueError(f"Invalid mode: {args.mode}")
+    # --- End Client Initialization ---
+
+    if args.reference_audio:
+        args.num_tasks = 1
+        args.log_interval = 1
+        manifest_item_list = [
+            {
+                "reference_text": args.reference_text,
+                "target_text": args.target_text,
+                "audio_filepath": args.reference_audio,
+                "target_audio_path": "test",
+            }
+        ]
+    elif args.huggingface_dataset:
+        import datasets
+
+        dataset = datasets.load_dataset(
+            args.huggingface_dataset,
+            split=args.split_name,
+            trust_remote_code=True,
+        )
+        manifest_item_list = []
+        for i in range(len(dataset)):
+            manifest_item_list.append(
+                {
+                    "audio_filepath": dataset[i]["prompt_audio"],
+                    "reference_text": dataset[i]["prompt_text"],
+                    "target_audio_path": dataset[i]["id"],
+                    "target_text": dataset[i]["target_text"],
+                }
+            )
+    else:
+        manifest_item_list = load_manifests(args.manifest_path)
+
+    num_tasks = min(args.num_tasks, len(manifest_item_list))
+    manifest_item_list = split_data(manifest_item_list, num_tasks)
+
+    os.makedirs(args.log_dir, exist_ok=True)
+    tasks = []
+    start_time = time.time()
+    for i in range(num_tasks):
+        # --- Task Creation based on mode ---
+        if args.mode == "offline":
+            task = asyncio.create_task(
+                send(
+                    manifest_item_list[i],
+                    name=f"task-{i}",
+                    triton_client=triton_client,
+                    protocol_client=protocol_client,
+                    log_interval=args.log_interval,
+                    model_name=args.model_name,
+                    audio_save_dir=args.log_dir,
+                    padding_duration=1,
+                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
+                )
+            )
+        elif args.mode == "streaming":
+            task = asyncio.create_task(
+                send_streaming(
+                    manifest_item_list[i],
+                    name=f"task-{i}",
+                    server_url=url,  # Pass URL instead of client
+                    protocol_client=protocol_client,
+                    log_interval=args.log_interval,
+                    model_name=args.model_name,
+                    audio_save_dir=args.log_dir,
+                    padding_duration=10,
+                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
+                    chunk_overlap_duration=args.chunk_overlap_duration,
+                )
+            )
+        # --- End Task Creation ---
+        tasks.append(task)
+
+    ans_list = await asyncio.gather(*tasks)
+
+    end_time = time.time()
+    elapsed = end_time - start_time
+
+    total_duration = 0.0
+    latency_data = []
+    for ans in ans_list:
+        if ans:
+            total_duration += ans[0]
+            latency_data.extend(ans[1])  # Use extend for list of lists
+        else:
+            print("Warning: A task returned None, possibly due to an error.")
+
+    if total_duration == 0:
+        print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
+        rtf = float('inf')
+    else:
+        rtf = elapsed / total_duration
+
+    s = f"Mode: {args.mode}\n"
+    s += f"RTF: {rtf:.4f}\n"
+    s += f"total_duration: {total_duration:.3f} seconds\n"
+    s += f"({total_duration / 3600:.2f} hours)\n"
+    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
+
+    # --- Statistics Reporting based on mode ---
+    if latency_data:
+        if args.mode == "offline":
+            # Original offline latency calculation
+            latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
+            if latency_list:
+                latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
+                latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
+                s += f"latency_variance: {latency_variance:.2f}\n"
+                s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
+                s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
+                s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
+                s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
+                s += f"average_latency_ms: {latency_ms:.2f}\n"
+            else:
+                s += "No latency data collected for offline mode.\n"
+
+        elif args.mode == "streaming":
+            # Calculate stats for total request latency and first chunk latency
+            total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
+            first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
+
+            s += "\n--- Total Request Latency ---\n"
+            if total_latency_list:
+                avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
+                variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
+                s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
+                s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
+                s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
+                s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
+                s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
+                s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
+            else:
+                s += "No total request latency data collected.\n"
+
+            s += "\n--- First Chunk Latency ---\n"
+            if first_chunk_latency_list:
+                avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
+                variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
+                s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
+                s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
+                s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
+                s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
+                s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
+                s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
+            else:
+                s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
+    else:
+        s += "No latency data collected.\n"
+    # --- End Statistics Reporting ---
+
+    print(s)
+    if args.manifest_path:
+        name = Path(args.manifest_path).stem
+    elif args.split_name:
+        name = args.split_name
+    elif args.reference_audio:
+        name = Path(args.reference_audio).stem
+    else:
+        name = "results"  # Default name if no manifest/split/audio provided
+    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
+        f.write(s)
+
+    # --- Statistics Fetching using temporary Async Client ---
+    # Use a separate async client for fetching stats regardless of mode
+    stats_client = None
+    try:
+        print("Initializing temporary async client for fetching stats...")
+        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
+        print("Fetching inference statistics...")
+        # Fetching for all models, filtering might be needed depending on server setup
+        stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
+        print("Fetching model config...")
+        metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
+
+        write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
+
+        with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
+            json.dump(metadata, f, indent=4)
+
+    except Exception as e:
+        print(f"Could not retrieve statistics or config: {e}")
+    finally:
+        if stats_client:
+            try:
+                print("Closing temporary async stats client...")
+                await stats_client.close()
+            except Exception as e:
+                print(f"Error closing async stats client: {e}")
+    # --- End Statistics Fetching ---
+
+
+if __name__ == "__main__":
+    # asyncio.run(main()) # Use TaskGroup for better exception handling if needed
+    async def run_main():
+        try:
+            await main()
+        except Exception as e:
+            print(f"An error occurred in main: {e}")
+            import traceback
+            traceback.print_exc()
+
+    asyncio.run(run_main())

+ 173 - 0
runtime/triton_trtllm/client_http.py

@@ -0,0 +1,173 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import requests
+import soundfile as sf
+import json
+import numpy as np
+import argparse
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--server-url",
+        type=str,
+        default="localhost:8000",
+        help="Address of the server",
+    )
+
+    parser.add_argument(
+        "--reference-audio",
+        type=str,
+        default="../../example/prompt_audio.wav",
+        help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
+    )
+
+    parser.add_argument(
+        "--reference-text",
+        type=str,
+        default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
+        help="",
+    )
+
+    parser.add_argument(
+        "--target-text",
+        type=str,
+        default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
+        help="",
+    )
+
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="spark_tts",
+        choices=[
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
+    )
+
+    parser.add_argument(
+        "--output-audio",
+        type=str,
+        default="output.wav",
+        help="Path to save the output audio",
+    )
+    return parser.parse_args()
+
+
+def prepare_request(
+    waveform,
+    reference_text,
+    target_text,
+    sample_rate=16000,
+    padding_duration: int = None,
+    audio_save_dir: str = "./",
+):
+    assert len(waveform.shape) == 1, "waveform should be 1D"
+    lengths = np.array([[len(waveform)]], dtype=np.int32)
+    if padding_duration:
+        # padding to nearset 10 seconds
+        samples = np.zeros(
+            (
+                1,
+                padding_duration
+                * sample_rate
+                * ((int(len(waveform) / sample_rate) // padding_duration) + 1),
+            ),
+            dtype=np.float32,
+        )
+
+        samples[0, : len(waveform)] = waveform
+    else:
+        samples = waveform
+
+    samples = samples.reshape(1, -1).astype(np.float32)
+
+    data = {
+        "inputs": [
+            {
+                "name": "reference_wav",
+                "shape": samples.shape,
+                "datatype": "FP32",
+                "data": samples.tolist()
+            },
+            {
+                "name": "reference_wav_len",
+                "shape": lengths.shape,
+                "datatype": "INT32",
+                "data": lengths.tolist(),
+            },
+            {
+                "name": "reference_text",
+                "shape": [1, 1],
+                "datatype": "BYTES",
+                "data": [reference_text]
+            },
+            {
+                "name": "target_text",
+                "shape": [1, 1],
+                "datatype": "BYTES",
+                "data": [target_text]
+            }
+        ]
+    }
+
+    return data
+
+
+if __name__ == "__main__":
+    args = get_args()
+    server_url = args.server_url
+    if not server_url.startswith(("http://", "https://")):
+        server_url = f"http://{server_url}"
+
+    url = f"{server_url}/v2/models/{args.model_name}/infer"
+    waveform, sr = sf.read(args.reference_audio)
+    assert sr == 16000, "sample rate hardcoded in server"
+
+    samples = np.array(waveform, dtype=np.float32)
+    data = prepare_request(samples, args.reference_text, args.target_text)
+
+    rsp = requests.post(
+        url,
+        headers={"Content-Type": "application/json"},
+        json=data,
+        verify=False,
+        params={"request_id": '0'}
+    )
+    result = rsp.json()
+    audio = result["outputs"][0]["data"]
+    audio = np.array(audio, dtype=np.float32)
+    if args.model_name == "spark_tts":
+        sample_rate = 16000
+    else:
+        sample_rate = 24000
+    sf.write(args.output_audio, audio, sample_rate, "PCM_16")

+ 20 - 0
runtime/triton_trtllm/docker-compose.yml

@@ -0,0 +1,20 @@
+services:
+  tts:
+    image: soar97/triton-cosyvoice:25.06
+    shm_size: '1gb'
+    ports:
+      - "8000:8000"
+      - "8001:8001"
+      - "8002:8002"
+    environment:
+      - PYTHONIOENCODING=utf-8
+      - MODEL_ID=${MODEL_ID}
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ['0']
+              capabilities: [gpu]
+    command: >
+      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/FunAudioLLM/CosyVoice.git && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3"

+ 97 - 0
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py

@@ -0,0 +1,97 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import json
+import torch
+from torch.utils.dlpack import to_dlpack
+
+import triton_python_backend_utils as pb_utils
+
+import os
+import numpy as np
+import s3tokenizer
+
+ORIGINAL_VOCAB_SIZE = 151663
+
+
+class TritonPythonModel:
+    """Triton Python model for audio tokenization.
+
+    This model takes reference audio input and extracts semantic tokens
+    using s3tokenizer.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+
+        self.device = torch.device("cuda")
+        model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
+        self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
+
+    def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing tokenized outputs
+        """
+        mels = []
+
+        # Process each request in batch
+        for request in requests:
+            # Extract input tensors
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav_len").as_numpy().item()
+
+            wav_array = torch.from_numpy(wav_array).to(self.device)
+            # Prepare inputs
+            wav = wav_array[:, :wav_len].squeeze(0)
+            mels.append(s3tokenizer.log_mel_spectrogram(wav))
+
+        mels, mels_lens = s3tokenizer.padding(mels)
+        codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
+        codes = codes.clone() + ORIGINAL_VOCAB_SIZE
+
+        responses = []
+        for i in range(len(requests)):
+            prompt_speech_tokens = codes[i, :codes_lens[i].item()]
+            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
+                "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[prompt_speech_tokens_tensor])
+            responses.append(inference_response)
+
+        return responses

+ 53 - 0
runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt

@@ -0,0 +1,53 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "audio_tokenizer"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+parameters [
+  {
+   key: "model_dir", 
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+  },
+  {
+    name: "reference_wav_len"
+    data_type: TYPE_INT32
+    dims: [1]
+  }
+]
+output [
+  {
+    name: "prompt_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]

+ 346 - 0
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

@@ -0,0 +1,346 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import math
+import os
+import re
+from typing import Dict, List, Tuple, Optional, Union
+
+import numpy as np
+import torch
+from torch.utils.dlpack import from_dlpack, to_dlpack
+import triton_python_backend_utils as pb_utils
+from transformers import AutoTokenizer
+import torchaudio.compliance.kaldi as kaldi
+import torchaudio
+import onnxruntime
+
+
+from matcha.utils.audio import mel_spectrogram
+
+
+class TritonPythonModel:
+    """Triton Python model for Spark TTS.
+
+    This model orchestrates the end-to-end TTS pipeline by coordinating
+    between audio tokenizer, LLM, and vocoder components.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        self.logger = pb_utils.Logger
+        # Parse model parameters
+        self.model_config = json.loads(args['model_config'])
+        parameters = self.model_config['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+        self.logger.log_info(f"model_params:{model_params}")
+
+        # Initialize tokenizer
+        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
+        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
+        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
+        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
+
+        self.device = torch.device("cuda")
+        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
+
+        campplus_model = f'{model_params["model_dir"]}/campplus.onnx'
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+
+    def forward_llm(self, input_ids):
+        """
+        Prepares the response from the language model based on the provided
+        inputs. Creates a `pb_utils.InferenceRequest` object with passed
+        `llm_request_inputs` to send to a decoupled TensorRTLLM model.
+        For each response from the language model:
+            - Checks for errors and raise an exception if any are found.
+            - Extracts the "output_ids" tensor from the response.
+            - Determines the finish reason based on the presence of the
+              end-of-sequence token or reaching the maximum length.
+            - Appends the generated token IDs to `output_ids`.
+            - If the finish reason is determined, decodes the output IDs to text
+              and prepares the final response.
+
+        The final response includes the generated text, finish reason,
+        completion tokens, prompt tokens, and total tokens.
+
+        Parameters
+        ----------
+        - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
+
+        Returns
+        -------
+        - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
+        """
+        # convert input_ids to numpy, with shape [1, sequence_length]
+        input_ids = input_ids.cpu().numpy()
+        max_tokens = 1024
+        input_dict = {
+            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
+            "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
+            "pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
+            "streaming": np.array([[self.decoupled]], dtype=np.bool_),
+            "runtime_top_p": np.array([[0.95]], dtype=np.float32),
+            "runtime_top_k": np.array([[50]], dtype=np.int32),
+            "temperature": np.array([[0.8]], dtype=np.float32),
+            "input_ids": input_ids,
+            "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
+        }
+
+        # Convert inputs to Triton tensors
+        input_tensor_list = [
+            pb_utils.Tensor(k, v) for k, v in input_dict.items()
+        ]
+
+        # Create and execute inference request
+        llm_request = pb_utils.InferenceRequest(
+            model_name="tensorrt_llm",
+            requested_output_names=["output_ids", "sequence_length"],
+            inputs=input_tensor_list,
+        )
+
+        llm_responses = llm_request.exec(decoupled=self.decoupled)
+        if self.decoupled:
+            for llm_response in llm_responses:
+                if llm_response.has_error():
+                    raise pb_utils.TritonModelException(llm_response.error().message())
+
+                # Extract and process output
+                output_ids = pb_utils.get_output_tensor_by_name(
+                    llm_response, "output_ids").as_numpy()
+                seq_lens = pb_utils.get_output_tensor_by_name(
+                    llm_response, "sequence_length").as_numpy()
+
+                # Get actual output IDs up to the sequence length
+                actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
+
+                yield actual_output_ids
+        else:
+            llm_response = llm_responses
+            if llm_response.has_error():
+                raise pb_utils.TritonModelException(llm_response.error().message())
+
+            # Extract and process output
+            output_ids = pb_utils.get_output_tensor_by_name(
+                llm_response, "output_ids").as_numpy()
+            seq_lens = pb_utils.get_output_tensor_by_name(
+                llm_response, "sequence_length").as_numpy()
+
+            # Get actual output IDs up to the sequence length
+            actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
+
+            yield actual_output_ids
+
+    def forward_audio_tokenizer(self, wav, wav_len):
+        """Forward pass through the audio tokenizer component.
+
+        Args:
+            wav: Input waveform tensor
+            wav_len: Waveform length tensor
+
+        Returns:
+            Tuple of global and semantic tokens
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='audio_tokenizer',
+            requested_output_names=['prompt_speech_tokens'],
+            inputs=[wav, wav_len]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
+        prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
+
+        return prompt_speech_tokens
+
+    def forward_token2wav(
+            self,
+            prompt_speech_tokens: torch.Tensor,
+            prompt_speech_feat: torch.Tensor,
+            prompt_spk_embedding: torch.Tensor,
+            target_speech_tokens: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the vocoder component.
+
+        Args:
+            prompt_speech_tokens: Prompt speech tokens tensor
+            prompt_speech_feat: Prompt speech feat tensor
+            prompt_spk_embedding: Prompt spk embedding tensor
+            target_speech_tokens: Target speech tokens tensor
+
+        Returns:
+            Generated waveform tensor
+        """
+        prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+        prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
+        prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
+        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
+
+        # Create and execute inference request
+        inference_request = pb_utils.InferenceRequest(
+            model_name='token2wav',
+            requested_output_names=['waveform'],
+            inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output waveform
+        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
+        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
+
+        return waveform
+
+    def parse_input(self, text, prompt_text, prompt_speech_tokens):
+        total_text = f"{prompt_text}{text}"
+        prompt = self.prompt_template.format(input_text=total_text)
+        input_ids = self.tokenizer.encode(prompt)
+        input_ids = torch.tensor([input_ids], dtype=torch.int32)
+        input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
+        return input_ids
+
+    def _extract_spk_embedding(self, speech):
+        feat = kaldi.fbank(speech,
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
+        feat = feat - feat.mean(dim=0, keepdim=True)
+        embedding = self.campplus_session.run(None,
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+        embedding = torch.tensor([embedding]).to(self.device).half()
+        return embedding
+
+    def _extract_speech_feat(self, speech):
+        speech_feat = mel_spectrogram(
+            speech,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000).squeeze(
+            dim=0).transpose(
+            0,
+            1).to(
+                self.device)
+        speech_feat = speech_feat.unsqueeze(dim=0)
+        return speech_feat
+
+    def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated audio
+        """
+        responses = []
+
+        for request in requests:
+            # Extract input tensors
+            wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+
+            # Process reference audio through audio tokenizer
+
+            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+            wav_tensor = wav.as_numpy()
+            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+            speech_feat = self._extract_speech_feat(prompt_speech_resample)
+            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+            reference_text = reference_text[0][0].decode('utf-8')
+
+            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+            target_text = target_text[0][0].decode('utf-8')
+
+            # Prepare prompt for LLM
+            input_ids = self.parse_input(
+                text=target_text,
+                prompt_text=reference_text,
+                prompt_speech_tokens=prompt_speech_tokens,
+            )
+
+            # Generate semantic tokens with LLM
+            generated_ids_iter = self.forward_llm(input_ids)
+
+            if self.decoupled:
+                response_sender = request.get_response_sender()
+                request_id = request.request_id()
+                generated_ids = []
+                for generated_id in generated_ids_iter:
+                    # convert the numpy array into a int32 tensor
+                    generated_id = generated_id.tolist()
+                    if len(generated_id) > 0:
+                        assert len(generated_id) == 1, "Generated ID is not a single integer"
+                        generated_ids.append(generated_id[0])
+                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
+                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
+                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
+
+                # Prepare response
+                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                response_sender.send(inference_response)
+                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+                self.logger.log_info("send tritonserver_response_complete_final to end")
+            else:
+                generated_ids = next(generated_ids_iter)
+                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
+                if generated_ids is None or len(generated_ids) == 0:
+                    raise pb_utils.TritonModelException("Generated IDs is None or empty")
+
+                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
+                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
+
+                # Prepare response
+                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                responses.append(inference_response)
+
+        if not self.decoupled:
+            return responses

+ 70 - 0
runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt

@@ -0,0 +1,70 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "cosyvoice2"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+model_transaction_policy {
+  decoupled: ${decoupled_mode}
+}
+parameters [
+  {
+   key: "llm_tokenizer_dir", 
+   value: {string_value:"${llm_tokenizer_dir}"}
+  },
+  {
+   key: "model_dir", 
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+  },
+  {
+    name: "reference_wav_len"
+    data_type: TYPE_INT32
+    dims: [1]
+  },
+  {
+    name: "reference_text"
+    data_type: TYPE_STRING
+    dims: [1]
+  },
+  {
+    name: "target_text"
+    data_type: TYPE_STRING
+    dims: [1]
+  }
+]
+output [
+  {
+    name: "waveform"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: ${bls_instance_num}
+    kind: KIND_CPU
+  }
+]

+ 0 - 0
runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep


+ 857 - 0
runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt

@@ -0,0 +1,857 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "tensorrt_llm"
+backend: "${triton_backend}"
+max_batch_size: ${triton_max_batch_size}
+
+model_transaction_policy {
+  decoupled: ${decoupled_mode}
+}
+
+dynamic_batching {
+    preferred_batch_size: [ ${triton_max_batch_size} ]
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+    default_queue_policy: { max_queue_size: ${max_queue_size} }
+}
+
+input [
+  {
+    name: "input_ids"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    allow_ragged_batch: true
+    optional: true
+  },
+  {
+    name: "encoder_input_features"
+    data_type: ${encoder_input_features_data_type}
+    dims: [ -1, -1 ]
+    allow_ragged_batch: true
+    optional: true
+  },
+  {
+    name: "encoder_output_lengths"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "input_lengths"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+  },
+  {
+    name: "request_output_len"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+  },
+  {
+    name: "num_return_sequences"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "draft_input_ids"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "decoder_input_ids"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "decoder_input_lengths"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    reshape: { shape: [ ] }
+  },
+  {
+    name: "draft_logits"
+    data_type: ${logits_datatype}
+    dims: [ -1, -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "draft_acceptance_threshold"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "end_id"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "pad_id"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "stop_words_list"
+    data_type: TYPE_INT32
+    dims: [ 2, -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "bad_words_list"
+    data_type: TYPE_INT32
+    dims: [ 2, -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "embedding_bias"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "beam_width"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "temperature"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "runtime_top_k"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "runtime_top_p"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "runtime_top_p_min"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "runtime_top_p_decay"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "runtime_top_p_reset_ids"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "len_penalty"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "early_stopping"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "repetition_penalty"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "min_length"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "beam_search_diversity_rate"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "presence_penalty"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "frequency_penalty"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "random_seed"
+    data_type: TYPE_UINT64
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "return_log_probs"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "return_context_logits"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "return_generation_logits"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "return_perf_metrics"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "exclude_input_in_output"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "stop"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "streaming"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "prompt_embedding_table"
+    data_type: TYPE_FP16
+    dims: [ -1, -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "prompt_table_extra_ids"
+    data_type: TYPE_UINT64
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "prompt_vocab_size"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  # cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
+  {
+    name: "cross_attention_mask"
+    data_type: TYPE_BOOL
+    dims: [ -1, -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  # Mrope param when mrope is used
+  {
+    name: "mrope_rotary_cos_sin"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+    optional: true
+  },
+  {
+    name: "mrope_position_deltas"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+    optional: true
+  },
+  # the unique task ID for the given LoRA.
+  # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
+  # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
+  # If the cache is full the oldest LoRA will be evicted to make space for new ones.  An error is returned if `lora_task_id` is not cached.
+  {
+    name: "lora_task_id"
+    data_type: TYPE_UINT64
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
+  # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
+  # each of the in / out tensors are first flattened and then concatenated together in the format above.
+  # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
+  {
+    name: "lora_weights"
+    data_type: TYPE_FP16
+    dims: [ -1, -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  # module identifier (same size a first dimension of lora_weights)
+  # See LoraModule::ModuleType for model id mapping
+  #
+  # "attn_qkv": 0     # compbined qkv adapter
+  # "attn_q": 1       # q adapter
+  # "attn_k": 2       # k adapter
+  # "attn_v": 3       # v adapter
+  # "attn_dense": 4   # adapter for the dense layer in attention
+  # "mlp_h_to_4h": 5  # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
+  # "mlp_4h_to_h": 6  # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
+  # "mlp_gate": 7     # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
+  #
+  # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
+  {
+    name: "lora_config"
+    data_type: TYPE_INT32
+    dims: [ -1, 3 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "context_phase_params"
+    data_type: TYPE_UINT8
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  # skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
+  {
+    name: "skip_cross_attn_blocks"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_token_range_starts"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_token_range_ends"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_token_range_priorities"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_token_range_durations_ms"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_decode_priority"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_decode_duration_ms"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "guided_decoding_guide_type"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "guided_decoding_guide"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "lookahead_window_size"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "lookahead_ngram_size"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "lookahead_verification_set_size"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  }
+]
+output [
+  {
+    name: "output_ids"
+    data_type: TYPE_INT32
+    dims: [ -1, -1 ]
+  },
+  {
+    name: "sequence_length"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+  },
+  {
+    name: "cum_log_probs"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  },
+  {
+    name: "output_log_probs"
+    data_type: TYPE_FP32
+    dims: [ -1, -1 ]
+  },
+  {
+    name: "context_logits"
+    data_type: ${logits_datatype}
+    dims: [ -1, -1 ]
+  },
+  {
+    name: "generation_logits"
+    data_type: ${logits_datatype}
+    dims: [ -1, -1, -1 ]
+  },
+  {
+    name: "batch_index"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "sequence_index"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "context_phase_params"
+    data_type: TYPE_UINT8
+    dims: [ -1 ]
+  },
+  {
+    name: "kv_cache_alloc_new_blocks"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "kv_cache_reused_blocks"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "kv_cache_alloc_total_blocks"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "arrival_time_ns"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+  },
+  {
+    name: "first_scheduled_time_ns"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+  },
+  {
+    name: "first_token_time_ns"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+  },
+  {
+    name: "last_token_time_ns"
+    data_type: TYPE_INT64
+    dims: [ 1 ]
+  },
+  {
+    name: "acceptance_rate"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+  },
+  {
+    name: "total_accepted_draft_tokens"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "total_draft_tokens"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  }
+]
+instance_group [
+  {
+    count: 1
+    kind : KIND_CPU
+  }
+]
+parameters: {
+  key: "max_beam_width"
+  value: {
+    string_value: "${max_beam_width}"
+  }
+}
+parameters: {
+  key: "FORCE_CPU_ONLY_INPUT_TENSORS"
+  value: {
+    string_value: "no"
+  }
+}
+parameters: {
+  key: "gpt_model_type"
+  value: {
+    string_value: "${batching_strategy}"
+  }
+}
+parameters: {
+  key: "gpt_model_path"
+  value: {
+    string_value: "${engine_dir}"
+  }
+}
+parameters: {
+  key: "encoder_model_path"
+  value: {
+    string_value: "${encoder_engine_dir}"
+  }
+}
+parameters: {
+  key: "max_tokens_in_paged_kv_cache"
+  value: {
+    string_value: "${max_tokens_in_paged_kv_cache}"
+  }
+}
+parameters: {
+  key: "max_attention_window_size"
+  value: {
+    string_value: "${max_attention_window_size}"
+  }
+}
+parameters: {
+  key: "sink_token_length"
+  value: {
+    string_value: "${sink_token_length}"
+  }
+}
+parameters: {
+  key: "batch_scheduler_policy"
+  value: {
+    string_value: "${batch_scheduler_policy}"
+  }
+}
+parameters: {
+  key: "kv_cache_free_gpu_mem_fraction"
+  value: {
+    string_value: "${kv_cache_free_gpu_mem_fraction}"
+  }
+}
+parameters: {
+  key: "cross_kv_cache_fraction"
+  value: {
+    string_value: "${cross_kv_cache_fraction}"
+  }
+}
+parameters: {
+  key: "kv_cache_host_memory_bytes"
+  value: {
+    string_value: "${kv_cache_host_memory_bytes}"
+  }
+}
+# kv_cache_onboard_blocks is for internal implementation.
+parameters: {
+  key: "kv_cache_onboard_blocks"
+  value: {
+    string_value: "${kv_cache_onboard_blocks}"
+  }
+}
+# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
+# parameters: {
+#   key: "enable_trt_overlap"
+#   value: {
+#     string_value: "${enable_trt_overlap}"
+#   }
+# }
+parameters: {
+  key: "exclude_input_in_output"
+  value: {
+    string_value: "${exclude_input_in_output}"
+  }
+}
+parameters: {
+  key: "cancellation_check_period_ms"
+  value: {
+    string_value: "${cancellation_check_period_ms}"
+  }
+}
+parameters: {
+  key: "stats_check_period_ms"
+  value: {
+    string_value: "${stats_check_period_ms}"
+  }
+}
+parameters: {
+  key: "iter_stats_max_iterations"
+  value: {
+    string_value: "${iter_stats_max_iterations}"
+  }
+}
+parameters: {
+  key: "request_stats_max_iterations"
+  value: {
+    string_value: "${request_stats_max_iterations}"
+  }
+}
+parameters: {
+  key: "enable_kv_cache_reuse"
+  value: {
+    string_value: "${enable_kv_cache_reuse}"
+  }
+}
+parameters: {
+  key: "normalize_log_probs"
+  value: {
+    string_value: "${normalize_log_probs}"
+  }
+}
+parameters: {
+  key: "enable_chunked_context"
+  value: {
+    string_value: "${enable_chunked_context}"
+  }
+}
+parameters: {
+  key: "gpu_device_ids"
+  value: {
+    string_value: "${gpu_device_ids}"
+  }
+}
+parameters: {
+  key: "participant_ids"
+  value: {
+    string_value: "${participant_ids}"
+  }
+}
+parameters: {
+  key: "lora_cache_optimal_adapter_size"
+  value: {
+    string_value: "${lora_cache_optimal_adapter_size}"
+  }
+}
+parameters: {
+  key: "lora_cache_max_adapter_size"
+  value: {
+    string_value: "${lora_cache_max_adapter_size}"
+  }
+}
+parameters: {
+  key: "lora_cache_gpu_memory_fraction"
+  value: {
+    string_value: "${lora_cache_gpu_memory_fraction}"
+  }
+}
+parameters: {
+  key: "lora_cache_host_memory_bytes"
+  value: {
+    string_value: "${lora_cache_host_memory_bytes}"
+  }
+}
+parameters: {
+  key: "lora_prefetch_dir"
+  value: {
+    string_value: "${lora_prefetch_dir}"
+  }
+}
+parameters: {
+  key: "decoding_mode"
+  value: {
+    string_value: "${decoding_mode}"
+  }
+}
+parameters: {
+  key: "executor_worker_path"
+  value: {
+    string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
+  }
+}
+parameters: {
+  key: "lookahead_window_size"
+    value: {
+      string_value: "${lookahead_window_size}"
+  }
+}
+parameters: {
+  key: "lookahead_ngram_size"
+    value: {
+      string_value: "${lookahead_ngram_size}"
+  }
+}
+parameters: {
+  key: "lookahead_verification_set_size"
+    value: {
+      string_value: "${lookahead_verification_set_size}"
+  }
+}
+parameters: {
+  key: "medusa_choices"
+    value: {
+      string_value: "${medusa_choices}"
+  }
+}
+parameters: {
+  key: "eagle_choices"
+    value: {
+      string_value: "${eagle_choices}"
+  }
+}
+parameters: {
+  key: "gpu_weights_percent"
+    value: {
+      string_value: "${gpu_weights_percent}"
+  }
+}
+parameters: {
+  key: "enable_context_fmha_fp32_acc"
+  value: {
+    string_value: "${enable_context_fmha_fp32_acc}"
+  }
+}
+parameters: {
+  key: "multi_block_mode"
+  value: {
+    string_value: "${multi_block_mode}"
+  }
+}
+parameters: {
+  key: "cuda_graph_mode"
+  value: {
+    string_value: "${cuda_graph_mode}"
+  }
+}
+parameters: {
+  key: "cuda_graph_cache_size"
+  value: {
+    string_value: "${cuda_graph_cache_size}"
+  }
+}
+parameters: {
+  key: "speculative_decoding_fast_logits"
+  value: {
+    string_value: "${speculative_decoding_fast_logits}"
+  }
+}
+parameters: {
+  key: "tokenizer_dir"
+  value: {
+    string_value: "${tokenizer_dir}"
+  }
+}
+parameters: {
+  key: "guided_decoding_backend"
+  value: {
+    string_value: "${guided_decoding_backend}"
+  }
+}
+parameters: {
+  key: "xgrammar_tokenizer_info_path"
+  value: {
+    string_value: "${xgrammar_tokenizer_info_path}"
+  }
+}

+ 195 - 0
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -0,0 +1,195 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import os
+
+import logging
+from typing import List, Dict
+
+import torch
+from torch.utils.dlpack import to_dlpack
+
+import triton_python_backend_utils as pb_utils
+
+from hyperpyyaml import load_hyperpyyaml
+from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
+from cosyvoice.utils.common import TrtContextWrapper
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+ORIGINAL_VOCAB_SIZE = 151663
+
+
+class CosyVoice2:
+
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
+
+        self.model_dir = model_dir
+        self.fp16 = fp16
+
+        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
+        self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
+                                self.fp16)
+
+
+class CosyVoice2Model:
+
+    def __init__(self,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module,
+                 fp16: bool = False):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.flow = flow
+        self.hift = hift
+        self.fp16 = fp16
+        if self.fp16 is True:
+            self.flow.half()
+
+    def load_jit(self, flow_encoder_model):
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load(self, flow_model, hift_model):
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+        self.flow.to(self.device).eval()
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
+            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_trt_kwargs(self):
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
+        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
+        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
+        input_names = ["x", "mask", "mu", "cond"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+
+class TritonPythonModel:
+    """Triton Python model for vocoder.
+
+    This model takes global and semantic tokens as input and generates audio waveforms
+    using the BiCodec vocoder.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {key: value["string_value"] for key, value in parameters.items()}
+        model_dir = model_params["model_dir"]
+
+        # Initialize device and vocoder
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
+
+        self.token2wav_model = CosyVoice2(
+            model_dir, load_jit=True, load_trt=True, fp16=True
+        )
+
+        logger.info("Token2Wav initialized successfully")
+
+    def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated waveforms
+        """
+        responses = []
+        # Process each request in batch
+        for request in requests:
+            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
+            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
+            prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
+            prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
+
+            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
+            prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
+            prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
+            prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
+
+            # shift the speech tokens according to the original vocab size
+            prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
+            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
+
+            tts_mel, _ = self.token2wav_model.model.flow.inference(
+                token=target_speech_tokens,
+                token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
+                    self.device
+                ),
+                prompt_token=prompt_speech_tokens,
+                prompt_token_len=torch.tensor(
+                    [prompt_speech_tokens.shape[1]], dtype=torch.int32
+                ).to(self.device),
+                prompt_feat=prompt_speech_feat,
+                prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
+                embedding=prompt_spk_embedding,
+                streaming=False,
+                finalize=True,
+            )
+
+            audio_hat, _ = self.token2wav_model.model.hift.inference(
+                speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+            )
+
+            generated_wave = audio_hat.squeeze(0).cpu().numpy()
+
+            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
+            responses.append(inference_response)
+
+        return responses

+ 63 - 0
runtime/triton_trtllm/model_repo/token2wav/config.pbtxt

@@ -0,0 +1,63 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "token2wav"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+parameters [
+  {
+   key: "model_dir", 
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "target_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+  },
+  {
+    name: "prompt_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+  },
+  {
+    name: "prompt_speech_feat"
+    data_type: TYPE_FP16
+    dims: [-1, 80]
+  },
+  {
+    name: "prompt_spk_embedding"
+    data_type: TYPE_FP16
+    dims: [-1]
+  }
+]
+output [
+  {
+    name: "waveform"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]

+ 14 - 0
runtime/triton_trtllm/requirements.txt

@@ -0,0 +1,14 @@
+hyperpyyaml
+s3tokenizer
+onnxruntime-gpu
+omegaconf
+conformer
+hydra-core
+lightning
+gdown
+wget
+librosa
+pyworld
+openai-whisper
+tritonclient
+modelscope

+ 105 - 0
runtime/triton_trtllm/run.sh

@@ -0,0 +1,105 @@
+
+export CUDA_VISIBLE_DEVICES=0
+cosyvoice_path=/workspace/CosyVoice
+export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
+export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
+stage=$1
+stop_stage=$2
+
+huggingface_model_local_dir=./cosyvoice2_llm
+model_scope_model_local_dir=./CosyVoice2-0.5B
+trt_dtype=bfloat16
+trt_weights_dir=./trt_weights_${trt_dtype}
+trt_engines_dir=./trt_engines_${trt_dtype}
+
+model_repo=./model_repo_cosyvoice2
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+    echo "Cloning CosyVoice"
+    git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
+    cd $cosyvoice_path
+    git submodule update --init --recursive
+    cd runtime/triton_trtllm
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+    echo "Downloading CosyVoice2-0.5B"
+    huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm 
+    modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir 
+fi
+
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+    echo "Converting checkpoint to TensorRT weights"
+    python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
+                                --output_dir $trt_weights_dir \
+                                --dtype $trt_dtype || exit 1
+
+    echo "Building TensorRT engines"
+    trtllm-build --checkpoint_dir $trt_weights_dir \
+                --output_dir $trt_engines_dir \
+                --max_batch_size 16 \
+                --max_num_tokens 32768 \
+                --gemm_plugin $trt_dtype || exit 1
+
+    echo "Testing TensorRT engines"
+    python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
+                    --tokenizer_dir $huggingface_model_local_dir \
+                    --top_k 50 --top_p 0.95 --temperature 0.8 \
+                    --engine_dir=$trt_engines_dir  || exit 1
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+    echo "Creating model repository"
+    rm -rf $model_repo
+    mkdir -p $model_repo
+    cosyvoice2_dir="cosyvoice2"
+
+    cp -r ./model_repo/${cosyvoice2_dir} $model_repo
+    cp -r ./model_repo/audio_tokenizer $model_repo
+    cp -r ./model_repo/tensorrt_llm $model_repo
+    cp -r ./model_repo/token2wav $model_repo
+
+    ENGINE_PATH=$trt_engines_dir
+    MAX_QUEUE_DELAY_MICROSECONDS=0
+    MODEL_DIR=$model_scope_model_local_dir
+    LLM_TOKENIZER_DIR=$huggingface_model_local_dir
+    BLS_INSTANCE_NUM=4
+    TRITON_MAX_BATCH_SIZE=16
+    DECOUPLED_MODE=False
+   
+    python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
+
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+   echo "Starting Triton server"
+   tritonserver --model-repository $model_repo
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+    echo "Single request test http"
+    python3 client_http.py \
+        --reference-audio ./assets/prompt_audio.wav \
+        --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
+        --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
+        --model-name cosyvoice2
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+    echo "Running benchmark client grpc"
+    num_task=4
+    # set mode=streaming, when decoupled=True
+    # set mode=offline, when decoupled=False
+    mode=offline
+    python3 client_grpc.py \
+        --server-addr localhost \
+        --model-name cosyvoice2 \
+        --num-tasks $num_task \
+        --mode $mode \
+        --huggingface-dataset yuekai/seed_tts_cosy2 \
+        --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype}
+fi

+ 330 - 0
runtime/triton_trtllm/scripts/convert_checkpoint.py

@@ -0,0 +1,330 @@
+import argparse
+import os
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from transformers import AutoConfig
+
+import tensorrt_llm
+from tensorrt_llm._utils import release_gc
+from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.models import QWenForCausalLM
+from tensorrt_llm.models.modeling_utils import QuantConfig
+from tensorrt_llm.quantization import QuantAlgo
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_dir', type=str, default=None, required=True)
+    parser.add_argument('--tp_size',
+                        type=int,
+                        default=1,
+                        help='N-way tensor parallelism size')
+    parser.add_argument('--pp_size',
+                        type=int,
+                        default=1,
+                        help='N-way pipeline parallelism size')
+    parser.add_argument('--cp_size',
+                        type=int,
+                        default=1,
+                        help='N-way context parallelism size')
+    parser.add_argument(
+        '--dtype',
+        type=str,
+        default='auto',
+        choices=['auto', 'float16', 'bfloat16', 'float32'],
+        help="The data type for the model weights and activations if not quantized. "
+        "If 'auto', the data type is automatically inferred from the source model; "
+        "however, if the source dtype is float32, it is converted to float16.")
+    parser.add_argument(
+        '--use_weight_only',
+        default=False,
+        action="store_true",
+        help='Quantize weights for the various GEMMs to INT4/INT8.'
+        'See --weight_only_precision to set the precision')
+    parser.add_argument(
+        '--disable_weight_only_quant_plugin',
+        default=False,
+        action="store_true",
+        help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
+        'You must also use --use_weight_only for that argument to have an impact.'
+    )
+    parser.add_argument(
+        '--weight_only_precision',
+        const='int8',
+        type=str,
+        nargs='?',
+        default='int8',
+        choices=['int8', 'int4', 'int4_gptq'],
+        help='Define the precision for the weights when using weight-only quantization.'
+        'You must also use --use_weight_only for that argument to have an impact.'
+    )
+    parser.add_argument(
+        '--calib_dataset',
+        type=str,
+        default='ccdv/cnn_dailymail',
+        help="The huggingface dataset name or the local directory of the dataset for calibration."
+    )
+    parser.add_argument(
+        "--smoothquant",
+        "-sq",
+        type=float,
+        default=None,
+        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
+        " to Smoothquant the model, and output int8 weights."
+        " A good first try is 0.5. Must be in [0, 1]")
+    parser.add_argument(
+        '--per_channel',
+        action="store_true",
+        default=False,
+        help='By default, we use a single static scaling factor for the GEMM\'s result. '
+        'per_channel instead uses a different static scaling factor for each channel. '
+        'The latter is usually more accurate, but a little slower.')
+    parser.add_argument(
+        '--per_token',
+        action="store_true",
+        default=False,
+        help='By default, we use a single static scaling factor to scale activations in the int8 range. '
+        'per_token chooses at run time, and for each token, a custom scaling factor. '
+        'The latter is usually more accurate, but a little slower.')
+    parser.add_argument(
+        '--int8_kv_cache',
+        default=False,
+        action="store_true",
+        help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
+    )
+    parser.add_argument(
+        '--per_group',
+        default=False,
+        action="store_true",
+        help='By default, we use a single static scaling factor to scale weights in the int4 range. '
+        'per_group chooses at run time, and for each group, a custom scaling factor. '
+        'The flag is built for GPTQ/AWQ quantization.')
+
+    parser.add_argument('--group_size',
+                        type=int,
+                        default=128,
+                        help='Group size used in GPTQ quantization.')
+
+    parser.add_argument("--load_model_on_cpu", action="store_true")
+    parser.add_argument(
+        '--use_parallel_embedding',
+        action="store_true",
+        default=False,
+        help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
+    )
+    parser.add_argument(
+        '--embedding_sharding_dim',
+        type=int,
+        default=0,
+        choices=[0, 1],
+        help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
+        'To shard it along hidden dimension, set embedding_sharding_dim=1'
+        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
+    )
+    parser.add_argument('--output_dir',
+                        type=str,
+                        default='tllm_checkpoint',
+                        help='The path to save the TensorRT-LLM checkpoint')
+    parser.add_argument(
+        '--workers',
+        type=int,
+        default=1,
+        help='The number of workers for converting checkpoint in parallel')
+    parser.add_argument(
+        '--moe_tp_size',
+        type=int,
+        default=-1,
+        help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
+    )
+    parser.add_argument(
+        '--moe_ep_size',
+        type=int,
+        default=-1,
+        help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
+    )
+    args = parser.parse_args()
+    return args
+
+
+def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
+    '''return config dict with quantization info based on the command line args
+    '''
+    quant_config = QuantConfig()
+    if args.use_weight_only:
+        if args.weight_only_precision == 'int8':
+            quant_config.quant_algo = QuantAlgo.W8A16
+        elif args.weight_only_precision == 'int4':
+            quant_config.quant_algo = QuantAlgo.W4A16
+    elif args.smoothquant:
+        quant_config.smoothquant_val = args.smoothquant
+        if args.per_channel:
+            if args.per_token:
+                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
+            else:
+                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
+        else:
+            if args.per_token:
+                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
+            else:
+                quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
+
+    if args.int8_kv_cache:
+        quant_config.kv_cache_quant_algo = QuantAlgo.INT8
+
+    if args.weight_only_precision == 'int4_gptq':
+        quant_config.group_size = args.group_size
+        quant_config.has_zero_point = True
+        quant_config.pre_quant_scale = False
+        quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
+
+    return quant_config
+
+
+def update_quant_config_from_hf(quant_config, hf_config,
+                                override_fields) -> tuple[QuantConfig, dict]:
+    hf_config_dict = hf_config.to_dict()
+    if hf_config_dict.get('quantization_config'):
+        # update the quant_algo, and clamp_val.
+        if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
+            logger.info(
+                "Load quantization configs from huggingface model_config.")
+            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
+            quant_config.group_size = hf_config_dict['quantization_config'].get(
+                'group_size', 128)
+            quant_config.has_zero_point = hf_config_dict[
+                'quantization_config'].get('zero_point', False)
+            override_fields.update({"use_autoawq": True})
+        elif hf_config_dict['quantization_config'].get(
+                'quant_method') == 'gptq':
+            logger.info(
+                "Load quantization configs from huggingface model_config.")
+            desc_act = hf_config_dict['quantization_config'].get(
+                'desc_act', False)
+            if desc_act:
+                raise ValueError("GPTQ with desc_act=True is not implemented!")
+            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
+            quant_config.group_size = hf_config_dict['quantization_config'].get(
+                'group_size', 128)
+            quant_config.has_zero_point = hf_config_dict[
+                'quantization_config'].get('sym', False)
+    return quant_config, override_fields
+
+
+def args_to_build_options(args):
+    return {
+        'use_parallel_embedding': args.use_parallel_embedding,
+        'embedding_sharding_dim': args.embedding_sharding_dim,
+        'disable_weight_only_quant_plugin':
+        args.disable_weight_only_quant_plugin
+    }
+
+
+def convert_and_save_hf(args):
+    model_dir = args.model_dir
+    world_size = args.tp_size * args.pp_size
+    # Need to convert the cli args to the kay-value pairs and override them in the generate config dict.
+    # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
+    # before the refactor is done.
+    override_fields = {}
+    override_fields.update(args_to_build_options(args))
+    quant_config = args_to_quant_config(args)
+
+    try:
+        hf_config = AutoConfig.from_pretrained(model_dir,
+                                               trust_remote_code=True)
+        quant_config, override_fields = update_quant_config_from_hf(
+            quant_config, hf_config, override_fields)
+    except BaseException:
+        logger.warning("AutoConfig cannot load the huggingface config.")
+
+    if args.smoothquant is not None or args.int8_kv_cache:
+        mapping = Mapping(world_size=world_size,
+                          tp_size=args.tp_size,
+                          pp_size=args.pp_size,
+                          moe_tp_size=args.moe_tp_size,
+                          moe_ep_size=args.moe_ep_size,
+                          cp_size=args.cp_size)
+        QWenForCausalLM.quantize(args.model_dir,
+                                 args.output_dir,
+                                 dtype=args.dtype,
+                                 mapping=mapping,
+                                 quant_config=quant_config,
+                                 calib_dataset=args.calib_dataset,
+                                 **override_fields)
+    else:
+
+        def convert_and_save_rank(args, rank):
+            mapping = Mapping(world_size=world_size,
+                              rank=rank,
+                              tp_size=args.tp_size,
+                              pp_size=args.pp_size,
+                              moe_tp_size=args.moe_tp_size,
+                              moe_ep_size=args.moe_ep_size)
+            qwen = QWenForCausalLM.from_hugging_face(model_dir,
+                                                     args.dtype,
+                                                     mapping=mapping,
+                                                     quant_config=quant_config,
+                                                     **override_fields)
+            qwen.config.mapping.cp_size = args.cp_size
+            qwen.config.mapping.attn_tp_size = -1
+            qwen.config.mapping.attn_cp_size = -1
+            qwen.config.mapping.world_size *= args.cp_size
+            qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
+            del qwen
+
+        execute(args.workers, [convert_and_save_rank] * world_size, args)
+        release_gc()
+
+
+def execute(workers, func, args):
+    if workers == 1:
+        for rank, f in enumerate(func):
+            f(args, rank)
+    else:
+        with ThreadPoolExecutor(max_workers=workers) as p:
+            futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
+            exceptions = []
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    traceback.print_exc()
+                    exceptions.append(e)
+            assert len(
+                exceptions
+            ) == 0, "Checkpoint conversion failed, please check error log."
+
+
+def main():
+    print(tensorrt_llm.__version__)
+    args = parse_arguments()
+
+    if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
+        # moe default to tp-only
+        args.moe_tp_size = args.tp_size
+        args.moe_ep_size = 1
+    elif (args.moe_tp_size == -1):
+        args.moe_tp_size = args.tp_size // args.moe_ep_size
+    elif (args.moe_ep_size == -1):
+        args.moe_ep_size = args.tp_size // args.moe_tp_size
+    assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
+            ), "moe_tp_size * moe_ep_size must equal to tp_size"
+
+    tik = time.time()
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    assert args.model_dir is not None
+    convert_and_save_hf(args)
+
+    tok = time.time()
+    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
+    print(f'Total time of converting checkpoints: {t}')
+
+
+if __name__ == '__main__':
+    main()

+ 69 - 0
runtime/triton_trtllm/scripts/fill_template.py

@@ -0,0 +1,69 @@
+# /usr/bin/env python3
+from argparse import ArgumentParser
+from string import Template
+
+
+def split(string, delimiter):
+    """Split a string using delimiter. Supports escaping.
+
+    Args:
+        string (str): The string to split.
+        delimiter (str): The delimiter to split the string with.
+
+    Returns:
+        list: A list of strings.
+    """
+    result = []
+    current = ""
+    escape = False
+    for char in string:
+        if escape:
+            current += char
+            escape = False
+        elif char == delimiter:
+            result.append(current)
+            current = ""
+        elif char == "\\":
+            escape = True
+        else:
+            current += char
+    result.append(current)
+    return result
+
+
+def main(file_path, substitutions, in_place):
+    with open(file_path) as f:
+        pbtxt = Template(f.read())
+
+    sub_dict = {
+        "max_queue_size": 0,
+        'max_queue_delay_microseconds': 0,
+    }
+    for sub in split(substitutions, ","):
+        key, value = split(sub, ":")
+        sub_dict[key] = value
+
+        assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
+
+    pbtxt = pbtxt.safe_substitute(sub_dict)
+
+    if in_place:
+        with open(file_path, "w") as f:
+            f.write(pbtxt)
+    else:
+        print(pbtxt)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("file_path", help="path of the .pbtxt to modify")
+    parser.add_argument(
+        "substitutions",
+        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
+    )
+    parser.add_argument("--in_place",
+                        "-i",
+                        action="store_true",
+                        help="do the operation in-place")
+    args = parser.parse_args()
+    main(**vars(args))

+ 143 - 0
runtime/triton_trtllm/scripts/test_llm.py

@@ -0,0 +1,143 @@
+
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import ast
+import csv
+import os
+from pathlib import Path
+from typing import List, Optional
+
+import numpy as np
+import torch
+
+import tensorrt_llm
+from tensorrt_llm.logger import logger
+
+from tensorrt_llm.runtime import ModelRunnerCpp
+from transformers import AutoTokenizer
+
+
+def parse_arguments(args=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--input_text',
+        type=str,
+        nargs='+',
+        default=["Born in north-east France, Soyer trained as a"])
+    parser.add_argument('--tokenizer_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
+    parser.add_argument('--engine_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
+    parser.add_argument('--log_level', type=str, default="debug")
+    parser.add_argument('--kv_cache_free_gpu_memory_fraction', type=float, default=0.6)
+    parser.add_argument('--temperature', type=float, default=0.8)
+    parser.add_argument('--top_k', type=int, default=50)
+    parser.add_argument('--top_p', type=float, default=0.95)
+
+    return parser.parse_args(args=args)
+
+
+def parse_input(tokenizer,
+                input_text=None,
+                prompt_template=None):
+    batch_input_ids = []
+    for curr_text in input_text:
+        if prompt_template is not None:
+            curr_text = prompt_template.format(input_text=curr_text)
+        input_ids = tokenizer.encode(
+            curr_text)
+        batch_input_ids.append(input_ids)
+
+    batch_input_ids = [
+        torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
+    ]
+
+    logger.debug(f"Input token ids (batch_size = {len(batch_input_ids)}):")
+    for i, input_ids in enumerate(batch_input_ids):
+        logger.debug(f"Request {i}: {input_ids.tolist()}")
+
+    return batch_input_ids
+
+
+def main(args):
+    runtime_rank = tensorrt_llm.mpi_rank()
+    logger.set_level(args.log_level)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+    prompt_template = "<|sos|>{input_text}<|task_id|>"
+    end_id = tokenizer.convert_tokens_to_ids("<|eos1|>")
+
+    batch_input_ids = parse_input(tokenizer=tokenizer,
+                                  input_text=args.input_text,
+                                  prompt_template=prompt_template)
+
+    input_lengths = [x.size(0) for x in batch_input_ids]
+
+    runner_kwargs = dict(
+        engine_dir=args.engine_dir,
+        rank=runtime_rank,
+        max_output_len=1024,
+        enable_context_fmha_fp32_acc=False,
+        max_batch_size=len(batch_input_ids),
+        max_input_len=max(input_lengths),
+        kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
+        cuda_graph_mode=False,
+        gather_generation_logits=False,
+    )
+
+    runner = ModelRunnerCpp.from_dir(**runner_kwargs)
+
+    with torch.no_grad():
+        outputs = runner.generate(
+            batch_input_ids=batch_input_ids,
+            max_new_tokens=1024,
+            end_id=end_id,
+            pad_id=end_id,
+            temperature=args.temperature,
+            top_k=args.top_k,
+            top_p=args.top_p,
+            num_return_sequences=1,
+            repetition_penalty=1.1,
+            random_seed=42,
+            streaming=False,
+            output_sequence_lengths=True,
+            output_generation_logits=False,
+            return_dict=True,
+            return_all_generated_tokens=False)
+        torch.cuda.synchronize()
+        output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
+        num_output_sents, num_beams, _ = output_ids.size()
+        assert num_beams == 1
+        beam = 0
+        batch_size = len(input_lengths)
+        num_return_sequences = num_output_sents // batch_size
+        assert num_return_sequences == 1
+        for i in range(batch_size * num_return_sequences):
+            batch_idx = i // num_return_sequences
+            seq_idx = i % num_return_sequences
+            inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist()
+            input_text = tokenizer.decode(inputs)
+            print(f'Input [Text {batch_idx}]: \"{input_text}\"')
+            output_begin = input_lengths[batch_idx]
+            output_end = sequence_lengths[i][beam]
+            outputs = output_ids[i][beam][output_begin:output_end].tolist()
+            output_text = tokenizer.decode(outputs)
+            print(f'Output [Text {batch_idx}]: \"{output_text}\"')
+            logger.debug(str(outputs))
+
+
+if __name__ == '__main__':
+    args = parse_arguments()
+    main(args)