Selaa lähdekoodia

Merge pull request #1598 from yuekaizhang/streaming

[Runtime] StepAudio2 Streaming DiT Token2Wav Integration
Xiang Lyu 1 kuukausi sitten
vanhempi
commit
6e01309e01

+ 2 - 2
examples/grpo/cosyvoice2/README.md

@@ -36,7 +36,7 @@ Stage `0` converts raw JSONL files into the parquet format expected by veRL:
 ```bash
 bash run.sh 0 0
 ```
-Create two JSONL files—`train.jsonl` and `test.jsonl`.  
+Create two JSONL files—`train.jsonl` and `test.jsonl`.
 The script will then generate two Parquet files:
 
 ```
@@ -111,7 +111,7 @@ bash run.sh 5 5
 
 The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
 > [!TIP]
->  However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format. 
+>  However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
 
 ## Results
 

+ 1 - 1
examples/grpo/cosyvoice2/infer_dataset.py

@@ -53,7 +53,7 @@ except RuntimeError:
     pass
 
 
-TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"  # noqa: E501
 
 
 def audio_decode_cosyvoice2(

+ 0 - 2
examples/grpo/cosyvoice2/pretrained_to_huggingface.py

@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-
 # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #

+ 3 - 3
examples/grpo/cosyvoice2/run.sh

@@ -33,7 +33,7 @@ fi
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
-  modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path 
+  modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
   python3 pretrained_to_huggingface.py \
     --pretrained-cosyvoice2-path $model_scope_model_path \
     --save-path $sft_model_path
@@ -61,7 +61,7 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "stage 1: start token2wav asr server for reward function"
   python3 token2wav_asr_server.py --number-of-devices 8
-fi 
+fi
 
 exp_name=official_llm_aishell3_grpo
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -125,7 +125,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
       --backend fsdp \
       --local_dir $llm_path/actor \
       --target_dir $llm_path/merged_hf_model || exit 1
-fi 
+fi
 
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "stage 4: Test the model"

+ 1 - 3
examples/grpo/cosyvoice2/scripts/offline-decode-files.py

@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-#
 # Copyright (c)  2023 by manyeyes
 # Copyright (c)  2023  Xiaomi Corporation
 
@@ -195,7 +193,7 @@ def write_error_stats(
             hyp = list("".join(hyp))
             results[i] = (cut_id, ref, hyp)
 
-    for cut_id, ref, hyp in results:
+    for _cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
         for ref_word, hyp_word in ali:
             if ref_word == ERR:

+ 1 - 1
examples/grpo/cosyvoice2/token2wav_asr_server.py

@@ -295,7 +295,7 @@ def main():
         metrics_port=8002,
     )
 
-    device_ids = [i for i in range(args.number_of_devices)]
+    device_ids = list(range(args.number_of_devices))
     device_ids = device_ids * args.number_of_instances_per_device
 
     with Triton(config=triton_config) as triton:

+ 141 - 0
runtime/triton_trtllm/README.DIT.md

@@ -0,0 +1,141 @@
+## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM
+
+Contributed by Yuekai Zhang (NVIDIA).
+
+This document describes how to accelerate CosyVoice with a DiT-based Token2Wav module from Step-Audio2, using NVIDIA Triton Inference Server and TensorRT-LLM.
+
+### Quick Start
+
+Launch the service directly with Docker Compose:
+```sh
+docker compose -f docker-compose.dit.yml up
+```
+
+### Build the Docker Image
+
+To build the image from scratch:
+```sh
+docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
+```
+
+### Run a Docker Container
+```sh
+your_mount_dir=/mnt:/mnt
+docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
+```
+
+### Understanding `run_stepaudio2_dit_token2wav.sh`
+
+The `run_stepaudio2_dit_token2wav.sh` script orchestrates the entire workflow through numbered stages.
+
+You can run a subset of stages with:
+```sh
+bash run_stepaudio2_dit_token2wav.sh <start_stage> <stop_stage>
+```
+- `<start_stage>`: The stage to start from.
+- `<stop_stage>`: The stage to stop after.
+
+**Stages:**
+
+- **Stage -1**: Clones the `Step-Audio2` and `CosyVoice` repositories.
+- **Stage 0**: Downloads the `cosyvoice2_llm`, `CosyVoice2-0.5B`, and `Step-Audio-2-mini` models.
+- **Stage 1**: Converts the HuggingFace checkpoint for the LLM to the TensorRT-LLM format and builds the TensorRT engines.
+- **Stage 2**: Creates the Triton model repository, including configurations for `cosyvoice2_dit` and `token2wav_dit`.
+- **Stage 3**: Launches the Triton Inference Server for Token2Wav module and uses `trtllm-serve` to deploy Cosyvoice2 LLM.
+- **Stage 4**: Runs the gRPC benchmark client for performance testing.
+- **Stage 5**: Runs the offline TTS inference benchmark test.
+- **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model.
+- **Stage 7**: Launches servers in a disaggregated setup, with the LLM on GPU 0 and Token2Wav servers on GPUs 1-3.
+- **Stage 8**: Runs the benchmark client for the disaggregated server configuration.
+### Export Models and Launch Server
+
+Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
+```sh
+# This command runs stages 0, 1, 2, and 3
+bash run_stepaudio2_dit_token2wav.sh 0 3
+```
+
+### Benchmark with client-server mode
+
+To benchmark the running Triton server, run stage 4:
+```sh
+bash run_stepaudio2_dit_token2wav.sh 4 4
+
+# You can customize parameters such as the number of tasks inside the script.
+```
+The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
+
+#### Total Request Latency
+
+| Concurrent Tasks | RTF    | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
+| ---------------- | ------ | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
+| 1                | 0.1228 | 833.66       | 779.98               | 1297.05              | 1555.97              | 1653.02              |
+| 2                | 0.0901 | 1166.23      | 1124.69              | 1762.76              | 1900.64              | 2204.14              |
+| 4                | 0.0741 | 1849.30      | 1759.42              | 2624.50              | 2822.20              | 3128.42              |
+| 6                | 0.0774 | 2936.13      | 3054.64              | 3849.60              | 3900.49              | 4245.79              |
+| 8                | 0.0691 | 3408.56      | 3434.98              | 4547.13              | 5047.76              | 5346.53              |
+| 10               | 0.0707 | 4306.56      | 4343.44              | 5769.64              | 5876.09              | 5939.79              |
+
+#### First Chunk Latency
+
+| Concurrent Tasks | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
+| ---------------- | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
+| 1                | 197.50       | 196.13               | 214.65               | 215.96               | 229.21               |
+| 2                |  281.15       | 278.20               | 345.18               | 361.79               | 395.97               |
+| 4                |  510.65       | 530.50               | 630.13               | 642.44               | 666.65               |
+| 6                |  921.54       | 918.86               | 1079.97              | 1265.22              | 1524.41              |
+| 8                |  1019.95      | 1085.26              | 1371.05              | 1402.24              | 1410.66              |
+| 10               |  1214.98      | 1293.54              | 1575.36              | 1654.51              | 2161.76              |
+
+### Benchmark with offline inference mode
+For offline inference mode benchmark, please run stage 5:
+```sh
+bash run_stepaudio2_dit_token2wav.sh 5 5
+```
+
+The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
+
+#### Offline TTS (Cosyvoice2 0.5B LLM + StepAudio2 DiT Token2Wav)
+| Backend | Batch Size | llm_time_seconds  | total_time_seconds | RTF |
+|---------|------------|------------------|-----------------------|--|
+| TRTLLM | 16 | 2.01 |  5.03 | 0.0292 |
+
+
+### Disaggregated Server
+When the LLM and token2wav components are deployed on the same GPU, they compete for resources. To optimize performance, we use a disaggregated setup where the LLM is deployed on one dedicated L20 GPU, taking advantage of in-flight batching for inference. The token2wav module is deployed on separate, dedicated GPUs.
+
+The table below shows the first chunk latency results for this configuration. In our tests, we deploy two token2wav instances on each dedicated token2wav GPU.
+
+| token2wav_num_gpu | concurrent_task_per_instance | concurrent_tasks_per_gpu | avg (ms) | p50 (ms) | p90 (ms) | p99 (ms) |
+|---|---|---|---|---|---|---|
+| 1 | 1 | 1.00 | 218.53 | 217.86 | 254.07 | 296.49 |
+| 2 | 1 | 1.33 | 218.82 | 219.21 | 256.62 | 303.13 |
+| 3 | 1 | 1.50 | 229.08 | 223.27 | 302.13 | 324.41 |
+| 4 | 1 | 1.60 | 203.87 | 198.23 | 254.92 | 279.31 |
+| 1 | 2 | 2.00 | 293.46 | 280.53 | 370.81 | 407.40 |
+| 2 | 2 | 2.67 | 263.38 | 236.84 | 350.82 | 397.39 |
+| 3 | 2 | 3.00 | 308.09 | 275.48 | 385.22 | 521.45 |
+| 4 | 2 | 3.20 | 271.85 | 253.25 | 359.03 | 387.91 |
+| 1 | 3 | 3.00 | 389.15 | 373.01 | 469.22 | 542.89 |
+| 2 | 3 | 4.00 | 403.48 | 394.80 | 481.24 | 507.75 |
+| 3 | 3 | 4.50 | 406.33 | 391.28 | 495.43 | 571.29 |
+| 4 | 3 | 4.80 | 436.72 | 383.81 | 638.44 | 879.23 |
+| 1 | 4 | 4.00 | 520.12 | 493.98 | 610.38 | 739.85 |
+| 2 | 4 | 5.33 | 494.60 | 490.50 | 605.93 | 708.09 |
+| 3 | 4 | 6.00 | 538.23 | 508.33 | 687.62 | 736.96 |
+| 4 | 4 | 6.40 | 579.68 | 546.20 | 721.53 | 958.04 |
+| 1 | 5 | 5.00 | 635.02 | 623.30 | 786.85 | 819.84 |
+| 2 | 5 | 6.67 | 598.23 | 617.09 | 741.00 | 788.96 |
+| 3 | 5 | 7.50 | 644.78 | 684.40 | 786.45 | 1009.45 |
+| 4 | 5 | 8.00 | 733.92 | 642.26 | 1024.79 | 1281.55 |
+| 1 | 6 | 6.00 | 715.38 | 745.68 | 887.04 | 906.68 |
+| 2 | 6 | 8.00 | 748.31 | 753.94 | 873.59 | 1007.14 |
+| 3 | 6 | 9.00 | 900.27 | 822.28 | 1431.14 | 1800.23 |
+| 4 | 6 | 9.60 | 857.54 | 820.33 | 1150.30 | 1298.53 |
+
+The `concurrent_task_per_gpu` is calculated as:
+`concurrent_task_per_gpu = concurrent_task_per_instance * num_token2wav_instance_per_gpu (2) * token2wav_gpus / (token2wav_gpus + llm_gpus (1))`
+
+### Acknowledgements
+
+This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).

+ 187 - 121
runtime/triton_trtllm/client_grpc.py

@@ -43,9 +43,9 @@ python3 client_grpc.py \
 import argparse
 import asyncio
 import json
-import queue  # Added
-import uuid  # Added
-import functools  # Added
+import queue
+import uuid
+import functools
 
 import os
 import time
@@ -55,16 +55,16 @@ from pathlib import Path
 import numpy as np
 import soundfile as sf
 import tritonclient
-import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
-import tritonclient.grpc as grpcclient_sync  # Added sync client import
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException
+import tritonclient.grpc.aio as grpcclient_aio
+import tritonclient.grpc as grpcclient_sync
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException
 
 
-# --- Added UserData and callback ---
 class UserData:
     def __init__(self):
         self._completed_requests = queue.Queue()
         self._first_chunk_time = None
+        self._second_chunk_time = None
         self._start_time = None
 
     def record_start_time(self):
@@ -75,39 +75,43 @@ class UserData:
             return self._first_chunk_time - self._start_time
         return None
 
+    def get_second_chunk_latency(self):
+        if self._first_chunk_time and self._second_chunk_time:
+            return self._second_chunk_time - self._first_chunk_time
+        return None
+
 
 def callback(user_data, result, error):
-    if user_data._first_chunk_time is None and not error:
-        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
+    if not error:
+        if user_data._first_chunk_time is None:
+            user_data._first_chunk_time = time.time()
+        elif user_data._second_chunk_time is None:
+            user_data._second_chunk_time = time.time()
+
     if error:
         user_data._completed_requests.put(error)
     else:
         user_data._completed_requests.put(result)
-# --- End Added UserData and callback ---
+
+
+def stream_callback(user_data_map, result, error):
+    request_id = None
+    if error:
+        print(f"An error occurred in the stream callback: {error}")
+    else:
+        request_id = result.get_response().id
+
+    if request_id:
+        user_data = user_data_map.get(request_id)
+        if user_data:
+            callback(user_data, result, error)
+        else:
+            print(f"Warning: Could not find user_data for request_id {request_id}")
 
 
 def write_triton_stats(stats, summary_file):
     with open(summary_file, "w") as summary_f:
         model_stats = stats["model_stats"]
-        # write a note, the log is from triton_client.get_inference_statistics(), to better human readability
-        summary_f.write(
-            "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
-        )
-        summary_f.write("To learn more about the log, please refer to: \n")
-        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
-        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
-        summary_f.write(
-            "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
-        )
-        summary_f.write(
-            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
-        )
-        summary_f.write(
-            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
-        )
-        summary_f.write(
-            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
-        )
         for model_state in model_stats:
             if "last_inference" not in model_state:
                 continue
@@ -118,7 +122,10 @@ def write_triton_stats(stats, summary_file):
             total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
             total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
             summary_f.write(
-                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
+                f"queue time {total_queue_time_s:<5.2f} s, "
+                f"compute infer time {total_infer_time_s:<5.2f} s, "
+                f"compute input time {total_input_time_s:<5.2f} s, "
+                f"compute output time {total_output_time_s:<5.2f} s \n"
             )
             model_batch_stats = model_state["batch_stats"]
             for batch in model_batch_stats:
@@ -127,21 +134,86 @@ def write_triton_stats(stats, summary_file):
                 compute_output = batch["compute_output"]
                 compute_infer = batch["compute_infer"]
                 batch_count = int(compute_infer["count"])
+                if batch_count == 0:
+                    continue
                 assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
                 compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
                 compute_input_time_ms = int(compute_input["ns"]) / 1e6
                 compute_output_time_ms = int(compute_output["ns"]) / 1e6
                 summary_f.write(
-                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
+                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
+                    f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
+                    f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
+                    f"{compute_infer_time_ms / batch_count:.2f} ms, "
+                    f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
+                    f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
                 )
                 summary_f.write(
-                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
+                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
                 )
                 summary_f.write(
-                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
+                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
                 )
 
 
+def subtract_stats(stats_after, stats_before):
+    """Subtracts two Triton inference statistics objects."""
+    stats_diff = json.loads(json.dumps(stats_after))
+
+    model_stats_before_map = {
+        s["name"]: {
+            "version": s["version"],
+            "last_inference": s.get("last_inference", 0),
+            "inference_count": s.get("inference_count", 0),
+            "execution_count": s.get("execution_count", 0),
+            "inference_stats": s.get("inference_stats", {}),
+            "batch_stats": s.get("batch_stats", []),
+        }
+        for s in stats_before["model_stats"]
+    }
+
+    for model_stat_after in stats_diff["model_stats"]:
+        model_name = model_stat_after["name"]
+        if model_name in model_stats_before_map:
+            model_stat_before = model_stats_before_map[model_name]
+
+            model_stat_after["inference_count"] = str(
+                int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
+            )
+            model_stat_after["execution_count"] = str(
+                int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
+            )
+
+            if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
+                for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
+                    if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
+                        if "ns" in model_stat_after["inference_stats"][key]:
+                            ns_after = int(model_stat_after["inference_stats"][key]["ns"])
+                            ns_before = int(model_stat_before["inference_stats"][key]["ns"])
+                            model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
+                        if "count" in model_stat_after["inference_stats"][key]:
+                            count_after = int(model_stat_after["inference_stats"][key]["count"])
+                            count_before = int(model_stat_before["inference_stats"][key]["count"])
+                            model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
+
+            if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
+                batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
+                for batch_stat_after in model_stat_after["batch_stats"]:
+                    bs = batch_stat_after["batch_size"]
+                    if bs in batch_stats_before_map:
+                        batch_stat_before = batch_stats_before_map[bs]
+                        for key in ["compute_input", "compute_infer", "compute_output"]:
+                            if key in batch_stat_after and key in batch_stat_before:
+                                count_after = int(batch_stat_after[key]["count"])
+                                count_before = int(batch_stat_before[key]["count"])
+                                batch_stat_after[key]["count"] = str(count_after - count_before)
+
+                                ns_after = int(batch_stat_after[key]["ns"])
+                                ns_before = int(batch_stat_before[key]["ns"])
+                                batch_stat_after[key]["ns"] = str(ns_after - ns_before)
+    return stats_diff
+
+
 def get_args():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
@@ -209,7 +281,8 @@ def get_args():
         choices=[
             "f5_tts",
             "spark_tts",
-            "cosyvoice2"],
+            "cosyvoice2",
+            "cosyvoice2_dit"],
         help="triton model_repo module name to request",
     )
 
@@ -243,7 +316,6 @@ def get_args():
         help="log directory",
     )
 
-    # --- Added arguments ---
     parser.add_argument(
         "--mode",
         type=str,
@@ -260,8 +332,8 @@ def get_args():
 
     parser.add_argument(
         "--use-spk2info-cache",
-        type=bool,
-        default=False,
+        type=str,
+        default="False",
         help="Use spk2info cache for reference audio.",
     )
 
@@ -284,39 +356,33 @@ def load_audio(wav_path, target_sample_rate=16000):
 
 
 def prepare_request_input_output(
-    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
+    protocol_client,
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None,  # Optional padding for offline mode
+    padding_duration: int = None,
     use_spk2info_cache: bool = False
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
     lengths = np.array([[len(waveform)]], dtype=np.int32)
 
-    # Apply padding only if padding_duration is provided (for offline)
     if padding_duration:
         duration = len(waveform) / sample_rate
-        # Estimate target duration based on text length ratio (crude estimation)
-        # Avoid division by zero if reference_text is empty
         if reference_text:
             estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
-            estimated_target_duration = duration  # Assume target duration similar to reference if no text
+            estimated_target_duration = duration
 
-        # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
             (int(estimated_target_duration + duration) // padding_duration) + 1
         )
         samples = np.zeros((1, required_total_samples), dtype=np.float32)
         samples[0, : len(waveform)] = waveform
     else:
-        # No padding for streaming or if padding_duration is None
         samples = waveform.reshape(1, -1).astype(np.float32)
 
-    # Common input creation logic
     inputs = [
         protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
         protocol_client.InferInput(
@@ -355,12 +421,8 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
-    user_data.record_start_time()  # Record start time for first chunk latency calculation
-
-    # Establish stream
-    sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
+    user_data.record_start_time()
 
-    # Send request
     sync_triton_client.async_stream_infer(
         model_name,
         inputs,
@@ -369,91 +431,76 @@ def run_sync_streaming_inference(
         enable_empty_final_response=True,
     )
 
-    # Process results
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get()  # Add timeout
+            result = user_data._completed_requests.get(timeout=200)
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
-                sync_triton_client.stop_stream()
-                return None, None, None  # Indicate error
-            # Get response metadata
+                return None, None, None, None
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
             if final is True:
                 break
 
             audio_chunk = result.as_numpy("waveform").reshape(-1)
-            if audio_chunk.size > 0:  # Only append non-empty chunks
+            if audio_chunk.size > 0:
                 audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")
 
         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
-            sync_triton_client.stop_stream()
-            return None, None, None  # Indicate error
+            return None, None, None, None
 
-    sync_triton_client.stop_stream()
     end_time_total = time.time()
     total_request_latency = end_time_total - start_time_total
     first_chunk_latency = user_data.get_first_chunk_latency()
+    second_chunk_latency = user_data.get_second_chunk_latency()
 
-    # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
-    actual_duration = 0
     if audios:
-        # Only spark_tts model uses cross-fade
         if model_name == "spark_tts":
             cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
             fade_out = np.linspace(1, 0, cross_fade_samples)
             fade_in = np.linspace(0, 1, cross_fade_samples)
             reconstructed_audio = None
 
-            # Simplified reconstruction based on client_grpc_streaming.py
             if not audios:
                 print("Warning: No audio chunks received.")
-                reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
+                reconstructed_audio = np.array([], dtype=np.float32)
             elif len(audios) == 1:
                 reconstructed_audio = audios[0]
             else:
-                reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
+                reconstructed_audio = audios[0][:-cross_fade_samples]
                 for i in range(1, len(audios)):
-                    # Cross-fade section
                     cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
                                            audios[i - 1][-cross_fade_samples:] * fade_out)
-                    # Middle section of the current chunk
                     middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-                    # Concatenate
                     reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
-                # Add the last part of the final chunk
                 reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
 
             if reconstructed_audio is not None and reconstructed_audio.size > 0:
                 actual_duration = len(reconstructed_audio) / save_sample_rate
-                # Save reconstructed audio
                 sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
             else:
                 print("Warning: No audio chunks received or reconstructed.")
-                actual_duration = 0  # Set duration to 0 if no audio
+                actual_duration = 0
         else:
             reconstructed_audio = np.concatenate(audios)
-            print(f"reconstructed_audio: {reconstructed_audio.shape}")
             actual_duration = len(reconstructed_audio) / save_sample_rate
-            # Save reconstructed audio
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
 
     else:
         print("Warning: No audio chunks received.")
         actual_duration = 0
 
-    return total_request_latency, first_chunk_latency, actual_duration
+    return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
 
 
 async def send_streaming(
     manifest_item_list: list,
     name: str,
-    server_url: str,  # Changed from sync_triton_client
+    server_url: str,
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -466,11 +513,13 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
-    sync_triton_client = None  # Initialize client variable
+    sync_triton_client = None
+    user_data_map = {}
 
-    try:  # Wrap in try...finally to ensure client closing
+    try:
         print(f"{name}: Initializing sync client for streaming...")
-        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
+        sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
 
         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -490,12 +539,13 @@ async def send_streaming(
                     padding_duration=padding_duration,
                     use_spk2info_cache=use_spk2info_cache
                 )
+
                 request_id = str(uuid.uuid4())
                 user_data = UserData()
+                user_data_map[request_id] = user_data
 
                 audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
-
-                total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
+                total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
                     run_sync_streaming_inference,
                     sync_triton_client,
                     model_name,
@@ -509,12 +559,18 @@ async def send_streaming(
                 )
 
                 if total_request_latency is not None:
-                    print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
-                    latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
+                    print(
+                        f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
+                        f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
+                        f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
+                    )
+                    latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
                     total_duration += actual_duration
                 else:
                     print(f"{name}: Item {i} failed.")
 
+                del user_data_map[request_id]
+
             except FileNotFoundError:
                 print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
             except Exception as e:
@@ -522,10 +578,11 @@ async def send_streaming(
                 import traceback
                 traceback.print_exc()
 
-    finally:  # Ensure client is closed
+    finally:
         if sync_triton_client:
             try:
-                print(f"{name}: Closing sync client...")
+                print(f"{name}: Closing stream and sync client...")
+                sync_triton_client.stop_stream()
                 sync_triton_client.close()
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")
@@ -550,7 +607,6 @@ async def send(
     latency_data = []
     task_id = int(name[5:])
 
-    print(f"manifest_item_list: {manifest_item_list}")
     for i, item in enumerate(manifest_item_list):
         if i % log_interval == 0:
             print(f"{name}: {i}/{len(manifest_item_list)}")
@@ -591,7 +647,6 @@ def load_manifests(manifest_path):
             assert len(line.strip().split("|")) == 4
             utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
             utt = Path(utt).stem
-            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
             if not os.path.isabs(prompt_wav):
                 prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
             manifest_list.append(
@@ -632,23 +687,17 @@ async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
 
-    # --- Client Initialization based on mode ---
     triton_client = None
     protocol_client = None
     if args.mode == "offline":
         print("Initializing gRPC client for offline mode...")
-        # Use the async client for offline tasks
         triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
         protocol_client = grpcclient_aio
     elif args.mode == "streaming":
         print("Initializing gRPC client for streaming mode...")
-        # Use the sync client for streaming tasks, handled via asyncio.to_thread
-        # We will create one sync client instance PER TASK inside send_streaming.
-        # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
-        protocol_client = grpcclient_sync  # protocol client for input prep
+        protocol_client = grpcclient_sync
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
-    # --- End Client Initialization ---
 
     if args.reference_audio:
         args.num_tasks = 1
@@ -682,15 +731,24 @@ async def main():
     else:
         manifest_item_list = load_manifests(args.manifest_path)
 
+    stats_client = None
+    stats_before = None
+    try:
+        print("Initializing temporary async client for fetching stats...")
+        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
+        print("Fetching inference statistics before running tasks...")
+        stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
+    except Exception as e:
+        print(f"Could not retrieve statistics before running tasks: {e}")
+
     num_tasks = min(args.num_tasks, len(manifest_item_list))
     manifest_item_list = split_data(manifest_item_list, num_tasks)
 
     os.makedirs(args.log_dir, exist_ok=True)
-
+    args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
     tasks = []
     start_time = time.time()
     for i in range(num_tasks):
-        # --- Task Creation based on mode ---
         if args.mode == "offline":
             task = asyncio.create_task(
                 send(
@@ -711,7 +769,7 @@ async def main():
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
-                    server_url=url,  # Pass URL instead of client
+                    server_url=url,
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -722,7 +780,6 @@ async def main():
                     use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
-        # --- End Task Creation ---
         tasks.append(task)
 
     ans_list = await asyncio.gather(*tasks)
@@ -735,7 +792,7 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
-            latency_data.extend(ans[1])  # Use extend for list of lists
+            latency_data.extend(ans[1])
         else:
             print("Warning: A task returned None, possibly due to an error.")
 
@@ -751,10 +808,8 @@ async def main():
     s += f"({total_duration / 3600:.2f} hours)\n"
     s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
 
-    # --- Statistics Reporting based on mode ---
     if latency_data:
         if args.mode == "offline":
-            # Original offline latency calculation
             latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
             if latency_list:
                 latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
@@ -769,9 +824,9 @@ async def main():
                 s += "No latency data collected for offline mode.\n"
 
         elif args.mode == "streaming":
-            # Calculate stats for total request latency and first chunk latency
-            total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
-            first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
+            total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
+            first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
+            second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
 
             s += "\n--- Total Request Latency ---\n"
             if total_latency_list:
@@ -798,9 +853,21 @@ async def main():
                 s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
             else:
                 s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
+
+            s += "\n--- Second Chunk Latency ---\n"
+            if second_chunk_latency_list:
+                avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
+                variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
+                s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
+                s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
+                s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
+                s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
+                s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
+                s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
+            else:
+                s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
     else:
         s += "No latency data collected.\n"
-    # --- End Statistics Reporting ---
 
     print(s)
     if args.manifest_path:
@@ -810,26 +877,27 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
-        name = "results"  # Default name if no manifest/split/audio provided
+        name = "results"
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)
 
-    # --- Statistics Fetching using temporary Async Client ---
-    # Use a separate async client for fetching stats regardless of mode
-    stats_client = None
     try:
-        print("Initializing temporary async client for fetching stats...")
-        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
-        print("Fetching inference statistics...")
-        # Fetching for all models, filtering might be needed depending on server setup
-        stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
-        print("Fetching model config...")
-        metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
+        if stats_client and stats_before:
+            print("Fetching inference statistics after running tasks...")
+            stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
+
+            print("Calculating statistics difference...")
+            stats = subtract_stats(stats_after, stats_before)
 
-        write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
+            print("Fetching model config...")
+            metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
 
-        with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
-            json.dump(metadata, f, indent=4)
+            write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
+
+            with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
+                json.dump(metadata, f, indent=4)
+        else:
+            print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
 
     except Exception as e:
         print(f"Could not retrieve statistics or config: {e}")
@@ -840,11 +908,9 @@ async def main():
                 await stats_client.close()
             except Exception as e:
                 print(f"Error closing async stats client: {e}")
-    # --- End Statistics Fetching ---
 
 
 if __name__ == "__main__":
-    # asyncio.run(main()) # Use TaskGroup for better exception handling if needed
     async def run_main():
         try:
             await main()

+ 0 - 1
runtime/triton_trtllm/client_http.py

@@ -25,7 +25,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import requests
 import soundfile as sf
-import json
 import numpy as np
 import argparse
 

+ 20 - 0
runtime/triton_trtllm/docker-compose.dit.yml

@@ -0,0 +1,20 @@
+services:
+  tts:
+    image: soar97/triton-cosyvoice:25.06
+    shm_size: '1gb'
+    ports:
+      - "8000:8000"
+      - "8001:8001"
+      - "8002:8002"
+    environment:
+      - PYTHONIOENCODING=utf-8
+      - MODEL_ID=${MODEL_ID}
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ['0']
+              capabilities: [gpu]
+    command: >
+      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3"

+ 0 - 3
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

@@ -25,12 +25,9 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import json
-import math
 import os
-import re
 import threading
 import time
-from typing import Dict, List, Tuple, Optional, Union
 
 import numpy as np
 import torch

+ 394 - 0
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py

@@ -0,0 +1,394 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import math
+import os
+import re
+import time
+from typing import Dict, List, Tuple, Optional, Union
+import asyncio
+import httpx
+
+import numpy as np
+import torch
+from torch.utils.dlpack import from_dlpack, to_dlpack
+import triton_python_backend_utils as pb_utils
+from transformers import AutoTokenizer
+
+import torchaudio
+
+
+from matcha.utils.audio import mel_spectrogram
+
+
+ORIGINAL_VOCAB_SIZE = 151663
+torch.set_num_threads(1)
+
+
+def parse_speech_token_string(response_text: str) -> List[int]:
+    """
+    Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
+    """
+    speech_tokens = response_text.strip().split('><')
+    if len(speech_tokens) > 1:
+        # Add back the missing '<' and '>' for proper parsing
+        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
+        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
+
+    speech_ids = []
+    for token_str in speech_tokens:
+        match = re.match(r'<\|s_(\d+)\|>', token_str)
+        if match:
+            speech_ids.append(int(match.group(1)))
+    return speech_ids
+
+
+class TritonPythonModel:
+    """Triton Python model for Spark TTS.
+
+    This model orchestrates the end-to-end TTS pipeline by coordinating
+    between audio tokenizer, LLM, and vocoder components.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        self.logger = pb_utils.Logger
+        # Parse model parameters
+        self.model_config = json.loads(args['model_config'])
+        parameters = self.model_config['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
+        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
+
+        # Initialize tokenizer
+        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
+        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
+        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
+        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
+
+        self.device = torch.device("cuda")
+        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
+
+        self.token_frame_rate = 25
+        self.flow_pre_lookahead_len = 3
+        self.token_hop_len = 15
+
+        self.http_client = httpx.AsyncClient()
+        self.api_base = "http://localhost:8000/v1/chat/completions"
+        self.speaker_cache = {}
+
+    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
+        """Converts a tensor or list of speech token IDs to a string representation."""
+        if isinstance(speech_tokens, torch.Tensor):
+            # Ensure tensor is on CPU and flattened
+            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
+
+        speech_id_str = ""
+        for token_id in speech_tokens:
+            # Convert token ID back to the speech number N
+            token_num = token_id - ORIGINAL_VOCAB_SIZE
+            speech_id_str += f"<|s_{token_num}|>"
+        return speech_id_str
+
+    async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
+        """
+        Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
+        """
+        full_text = f"{reference_text}{target_text}"
+        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
+
+        chat = [
+            {"role": "user", "content": full_text},
+            {"role": "assistant", "content": prompt_speech_tokens_str}
+        ]
+
+        payload = {
+            "model": "trt_engines_bfloat16",
+            "messages": chat,
+            "max_tokens": 750,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "stop": ["<|eos1|>", "<|eos|>"],
+            "stream": True,
+        }
+
+        buffer = ""
+        async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    line_data = line[len("data: "):].strip()
+                    if line_data == "[DONE]":
+                        break
+                    try:
+                        json_data = json.loads(line_data)
+                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
+                        if content:
+                            buffer += content
+                            while True:
+                                match = re.search(r"<\|s_(\d+)\|>", buffer)
+                                if not match:
+                                    break
+
+                                token_num = int(match.group(1))
+                                final_id = token_num + ORIGINAL_VOCAB_SIZE
+                                yield final_id
+                                buffer = buffer[match.end():]
+                    except json.JSONDecodeError:
+                        self.logger.log_info(f"Skipping non-JSON line: {line_data}")
+                        continue
+
+        # Process any remaining complete tokens in the buffer after the stream ends
+        while True:
+            match = re.search(r"<\|s_(\d+)\|>", buffer)
+            if not match:
+                break
+            token_num = int(match.group(1))
+            final_id = token_num + ORIGINAL_VOCAB_SIZE
+            yield final_id
+            buffer = buffer[match.end():]
+
+    def forward_audio_tokenizer(self, wav, wav_len):
+        """Forward pass through the audio tokenizer component.
+
+        Args:
+            wav: Input waveform tensor
+            wav_len: Waveform length tensor
+
+        Returns:
+            Tuple of global and semantic tokens
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='audio_tokenizer',
+            requested_output_names=['prompt_speech_tokens'],
+            inputs=[wav, wav_len]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
+        prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
+
+        return prompt_speech_tokens
+
+    def forward_speaker_embedding(self, wav):
+        """Forward pass through the speaker embedding component.
+
+        Args:
+            wav: Input waveform tensor
+
+        Returns:
+            Prompt speaker embedding tensor
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='speaker_embedding',
+            requested_output_names=['prompt_spk_embedding'],
+            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
+        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
+
+        return prompt_spk_embedding
+
+    async def forward_token2wav(
+            self,
+            index: int,
+            target_speech_tokens: torch.Tensor,
+            request_id: str,
+            reference_wav: object,
+            reference_wav_len: object,
+            finalize: bool = None) -> torch.Tensor:
+        """Forward pass through the vocoder component.
+
+        Args:
+            index: Index of the request
+            target_speech_tokens: Target speech tokens tensor
+            request_id: Request ID
+            reference_wav: Reference waveform tensor
+            reference_wav_len: Reference waveform length tensor
+            finalize: Whether to finalize the request
+
+        Returns:
+            Generated waveform tensor
+        """
+        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
+        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
+
+        # Create and execute inference request
+        inference_request = pb_utils.InferenceRequest(
+            model_name='token2wav_dit',
+            requested_output_names=[
+                "waveform",
+            ],
+            inputs=inputs_tensor,
+            request_id=request_id,
+            parameters={"priority": index + 1},
+        )
+
+        inference_response = await inference_request.async_exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output waveform
+        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
+        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
+
+        return waveform
+
+    def _extract_speech_feat(self, speech):
+        speech_feat = mel_spectrogram(
+            speech,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000).squeeze(
+            dim=0).transpose(
+            0,
+            1).to(
+                self.device)
+        speech_feat = speech_feat.unsqueeze(dim=0)
+        return speech_feat
+
+    async def _process_request(self, request):
+        request_id = request.request_id()
+
+        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+        reference_text = reference_text[0][0].decode('utf-8')
+
+        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+
+        if reference_text not in self.speaker_cache:
+            self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
+        prompt_speech_tokens = self.speaker_cache[reference_text]
+
+        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+        target_text = target_text[0][0].decode('utf-8')
+
+        if self.decoupled:
+            response_sender = request.get_response_sender()
+
+            semantic_token_ids_arr = []
+            token_offset, chunk_index = 0, 0
+            start_time = time.time()
+            this_token_hop_len = self.token_hop_len
+            async for generated_ids in self.forward_llm_async(
+                target_text=target_text,
+                reference_text=reference_text,
+                prompt_speech_tokens=prompt_speech_tokens,
+            ):
+                if not generated_ids:
+                    break
+                semantic_token_ids_arr.append(generated_ids)
+                while True:
+                    pending_num = len(semantic_token_ids_arr) - token_offset
+                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
+                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
+                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                        sub_tts_speech = await self.forward_token2wav(
+                            chunk_index,
+                            this_tts_speech_token, request_id, wav, wav_len, False
+                        )
+                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                        token_offset += this_token_hop_len
+
+                        if self.dynamic_chunk_strategy == "exponential":
+                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "equal":
+                            this_token_hop_len = self.token_hop_len
+                        elif self.dynamic_chunk_strategy == "time_based":
+                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
+                            cost_time = time.time() - start_time
+                            duration = token_offset / self.token_frame_rate
+                            if chunk_index > 0 and cost_time > 0:
+                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
+                                if avg_chunk_processing_time > 0:
+                                    multiples = (duration - cost_time) / avg_chunk_processing_time
+                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
+                                    if multiples > 4:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
+                                    elif multiples > 2:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
+                                    else:
+                                        this_token_hop_len = self.token_hop_len
+                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
+                        chunk_index += 1
+                    else:
+                        break
+
+            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
+            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
+            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+            response_sender.send(inference_response)
+
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+        else:
+            raise NotImplementedError("Offline TTS mode is not supported")
+
+    async def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated audio
+        """
+        tasks = [
+            asyncio.create_task(self._process_request(request))
+            for request in requests
+        ]
+        await asyncio.gather(*tasks)
+        return None
+
+    def finalize(self):
+        self.logger.log_info("Finalizing CosyVoice DIT model")
+        if hasattr(self, "http_client"):
+            asyncio.run(self.http_client.aclose())

+ 73 - 0
runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt

@@ -0,0 +1,73 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "cosyvoice2_dit"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+model_transaction_policy {
+  decoupled: ${decoupled_mode}
+}
+parameters [
+  {
+   key: "llm_tokenizer_dir",
+   value: {string_value:"${llm_tokenizer_dir}"}
+  },
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+    optional: true
+  },
+  {
+    name: "reference_wav_len"
+    data_type: TYPE_INT32
+    dims: [1]
+    optional: true
+  },
+  {
+    name: "reference_text"
+    data_type: TYPE_STRING
+    dims: [1]
+    optional: true
+  },
+  {
+    name: "target_text"
+    data_type: TYPE_STRING
+    dims: [1]
+  }
+]
+output [
+  {
+    name: "waveform"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: ${bls_instance_num}
+    kind: KIND_CPU
+  }
+]

+ 0 - 1
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -28,7 +28,6 @@ import json
 import os
 
 import logging
-from typing import List, Dict
 
 import torch
 from torch.utils.dlpack import to_dlpack

+ 142 - 0
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py

@@ -0,0 +1,142 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import os
+
+import logging
+from typing import List, Dict
+
+import torch
+from torch.utils.dlpack import to_dlpack
+from torch.nn import functional as F
+
+import triton_python_backend_utils as pb_utils
+
+from hyperpyyaml import load_hyperpyyaml
+from cosyvoice.utils.common import fade_in_out
+from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
+from cosyvoice.utils.common import TrtContextWrapper
+from collections import defaultdict
+import numpy as np
+from .token2wav_dit import CosyVoice2_Token2Wav
+import hashlib
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+ORIGINAL_VOCAB_SIZE = 151663
+torch.set_num_threads(1)
+
+
+def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
+    """
+    Generates a unique ID for a torch.Tensor.
+    Tensors with the same elements and properties will have the same ID.
+    """
+    # Convert tensor to a byte string
+    tensor_bytes = tensor.numpy().tobytes()
+
+    # Create a SHA-256 hash of the byte string
+    hasher = hashlib.sha256()
+    hasher.update(tensor_bytes)
+
+    return hasher.hexdigest()
+
+
+class TritonPythonModel:
+    """Triton Python model for vocoder.
+
+    This model takes global and semantic tokens as input and generates audio waveforms
+    using the BiCodec vocoder.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {key: value["string_value"] for key, value in parameters.items()}
+        model_dir = model_params["model_dir"]
+
+        # Initialize device and vocoder
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
+
+        # FIXME: device id settings
+        self.token2wav_model = CosyVoice2_Token2Wav(
+            model_dir, enable_trt=True, streaming=True
+        )
+        logger.info("Token2Wav initialized successfully")
+
+    def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated waveforms
+        """
+        responses = []
+        # Process each request in batch
+        for request in requests:
+            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
+            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
+            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
+            target_speech_tokens = target_speech_tokens.squeeze().tolist()
+
+            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+
+            request_id = request.request_id()
+
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav_len").as_numpy().item()
+
+            wav_array = torch.from_numpy(wav_array)
+            wav = wav_array[:, :wav_len].squeeze(0)
+
+            spk_id = get_spk_id_from_prompt_audio(wav)
+
+            audio_hat = self.token2wav_model.forward_streaming(
+                target_speech_tokens, finalize, request_id=request_id,
+                speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
+            )
+
+            outputs = []
+
+            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
+            outputs.append(wav_tensor)
+            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
+            responses.append(inference_response)
+
+        return responses

+ 510 - 0
runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py

@@ -0,0 +1,510 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Example Usage
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 token2wav.py --enable-trt || exit 1
+"""
+import torch
+# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
+from flashcosyvoice.modules.hifigan import HiFTGenerator
+from flashcosyvoice.utils.audio import mel_spectrogram
+import torchaudio.compliance.kaldi as kaldi
+import onnxruntime
+import s3tokenizer
+from torch.utils.data import DataLoader
+from datasets import load_dataset
+import torchaudio
+import os
+import logging
+import argparse
+import queue
+import time
+import numpy as np
+from hyperpyyaml import load_hyperpyyaml
+
+
+def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
+    """perform fade_in_out in tensor style
+    """
+    mel_overlap_len = int(window.shape[0] / 2)
+    fade_in_mel = fade_in_mel.clone()
+    fade_in_mel[..., :mel_overlap_len] = \
+        fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
+    return fade_in_mel
+
+
+def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
+    import tensorrt as trt
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
+    if dtype == torch.float16:
+        config.set_flag(trt.BuilderFlag.FP16)
+
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(trt_kwargs['input_names'])):
+        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
+    if dtype == torch.float16:
+        tensor_dtype = trt.DataType.HALF
+    elif dtype == torch.bfloat16:
+        tensor_dtype = trt.DataType.BF16
+    elif dtype == torch.float32:
+        tensor_dtype = trt.DataType.FLOAT
+    else:
+        raise ValueError('invalid dtype {}'.format(dtype))
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
+    logging.info("Succesfully convert onnx to trt...")
+
+
+class TrtContextWrapper:
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        self.trt_engine = trt_engine
+        self.device = device
+        for _ in range(trt_concurrent):
+            trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
+            assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
+            self.trt_context_pool.put([trt_context, trt_stream])
+        assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
+
+    def acquire_estimator(self):
+        return self.trt_context_pool.get(), self.trt_engine
+
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])
+
+
+class CosyVoice2_Token2Wav(torch.nn.Module):
+    def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
+        super().__init__()
+        self.device_id = device_id
+        self.device = f"cuda:{device_id}"
+        with open(f"{model_dir}/flow.yaml", "r") as f:
+            configs = load_hyperpyyaml(f)
+            self.flow = configs['flow']
+
+        self.dtype = dtype
+        self.flow.to(self.dtype)
+
+        self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
+        self.flow.to(self.device).eval()
+
+        self.hift = HiFTGenerator()
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.spk_model = onnxruntime.InferenceSession(
+            f"{model_dir}/campplus.onnx", sess_options=option,
+            providers=["CPUExecutionProvider"])
+        self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
+
+        gpu = "l20"
+        if enable_trt:
+            if streaming:
+                self.load_trt(
+                    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
+                    f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
+                    1,
+                    self.dtype, streaming
+                )
+            else:
+                self.load_trt(
+                    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
+                    f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                    1,
+                    self.dtype
+                )
+            self.load_spk_trt(
+                f'{model_dir}/campplus.{gpu}.fp32.trt',
+                f'{model_dir}/campplus.onnx',
+                1,
+                False
+            )
+
+        self.streaming_flow_cache = {}
+        self.speaker_cache = {}
+
+        self.mel_cache_len = 8  # hard-coded, 160ms
+        self.source_cache_len = int(self.mel_cache_len * 480)   # 50hz mel -> 24kHz wave
+        self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
+
+        # hifigan cache for streaming tts
+        self.hift_cache_dict = {}
+
+    def forward_spk_embedding(self, spk_feat):
+        if isinstance(self.spk_model, onnxruntime.InferenceSession):
+            return self.spk_model.run(
+                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+            )[0].flatten().tolist()
+        else:
+            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
+            # NOTE need to synchronize when switching stream
+            with torch.cuda.device(self.device_id):
+                torch.cuda.current_stream().synchronize()
+                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
+                batch_size = spk_feat.size(0)
+
+                with stream:
+                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
+                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
+
+                    data_ptrs = [spk_feat.contiguous().data_ptr(),
+                                 output_tensor.contiguous().data_ptr()]
+                    for i, j in enumerate(data_ptrs):
+
+                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
+                    # run trt engine
+                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                    torch.cuda.current_stream().synchronize()
+                self.spk_model.release_estimator(spk_model, stream)
+
+            return output_tensor.cpu().numpy().flatten().tolist()
+
+    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
+        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
+            trt_kwargs = self.get_spk_trt_kwargs()
+            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
+        import tensorrt as trt
+        with open(spk_model, 'rb') as f:
+            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
+        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_spk_trt_kwargs(self):
+        min_shape = [(1, 4, 80)]
+        opt_shape = [(1, 500, 80)]
+        max_shape = [(1, 3000, 80)]
+        input_names = ["input"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
+            opt_batch_size = 2
+            max_batch_size = 16
+            if streaming:
+                opt_batch_size, max_batch_size = 1, 1  # only support batch size 1 for streaming tts
+            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
+            convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
+        if streaming:
+            min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
+            opt_shape = [
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
+                (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
+                (16, opt_batch_size * 2, 8, 100, 128)
+            ]
+            max_shape = [
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
+                (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
+                (16, max_batch_size * 2, 8, 1000, 128)
+            ]
+            input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
+        else:
+            min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
+            opt_shape = [
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
+            ]
+            max_shape = [
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
+            ]
+            input_names = ["x", "mask", "mu", "cond", "t", "spks"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
+        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
+            prompt_speech_mels_list.append(log_mel)
+        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
+        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
+            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
+        )
+        for i in range(len(prompt_speech_tokens)):
+            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
+            prompt_speech_tokens_list.append(speech_tokens_i)
+        return prompt_speech_tokens_list
+
+    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
+        spk_emb_for_flow = []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
+            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
+            spk_emb = self.forward_spk_embedding(spk_feat)
+
+            spk_emb_for_flow.append(spk_emb)
+        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
+        if self.dtype != torch.float32:
+            spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
+        return spk_emb_for_flow
+
+    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
+        prompt_mels_for_flow = []
+        prompt_mels_lens_for_flow = []
+        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
+            assert len(audio.shape) == 1
+            audio = audio.unsqueeze(0)
+            if sample_rate != 24000:
+                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
+            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
+            mel_len = mel.shape[0]
+            prompt_mels_for_flow.append(mel)
+            prompt_mels_lens_for_flow.append(mel_len)
+        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
+            prompt_mels_for_flow, batch_first=True, padding_value=0
+        )  # [B, T', num_mels=80]
+        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
+        return prompt_mels_for_flow, prompt_mels_lens_for_flow
+
+    def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
+                     generated_speech_tokens_list: list[list[int]],
+                     prompt_mels_for_flow: torch.Tensor,
+                     prompt_mels_lens_for_flow: torch.Tensor,
+                     spk_emb_for_flow: torch.Tensor):
+        batch_size = prompt_mels_for_flow.shape[0]
+        flow_inputs = []
+        flow_inputs_lens = []
+        for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
+            flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
+            flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
+
+        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
+        flow_inputs_lens = torch.tensor(flow_inputs_lens)
+
+        with torch.amp.autocast(self.device, dtype=torch.float16):
+            generated_mels, generated_mels_lens = self.flow.inference(
+                flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
+                prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
+            )
+
+        return generated_mels, generated_mels_lens
+
+    def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
+        batch_size = generated_mels.shape[0]
+        generated_wavs = []
+        for i in range(batch_size):
+            mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
+            wav, _ = self.hift(speech_feat=mel)
+            generated_wavs.append(wav)
+        return generated_wavs
+
+    @torch.inference_mode()
+    def forward(
+        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
+    ):
+        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
+
+        prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
+
+        generated_mels, generated_mels_lens = self.forward_flow(
+            prompt_speech_tokens_list, generated_speech_tokens_list,
+            prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+        )
+
+        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
+        return generated_wavs
+
+    def prepare_prompt_audio(
+        self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
+    ):
+        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
+
+        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
+
+        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
+
+        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
+        return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+
+    def get_prompt_audio_cache_for_streaming_tts(
+        self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+    ):
+        assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
+        for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
+            prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
+        prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
+
+        cache = self.flow.setup_cache(
+            prompt_speech_tokens_tensor.to(self.device),
+            prompt_mels_for_flow.to(self.device),
+            spk_emb_for_flow.to(self.device),
+            n_timesteps=10
+        )
+        new_cache = {k: v.clone() for k, v in cache.items()}
+        # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
+        return new_cache
+
+    @torch.inference_mode()
+    def forward_streaming(
+        self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
+    ):
+        if speaker_id not in self.speaker_cache:
+            assert prompt_audio is not None, "prompt_audio is required for new speaker"
+            assert prompt_audio_sample_rate == 16000
+
+            prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])
+
+            token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
+            prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
+            prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]
+
+            prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
+
+            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+            self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
+
+        if request_id not in self.streaming_flow_cache:
+            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
+            self.hift_cache_dict[request_id] = dict(
+                mel=torch.zeros(1, 80, 0, device='cuda'),
+                source=torch.zeros(1, 1, 0, device='cuda'),
+                speech=torch.zeros(1, 0, device='cuda'),
+            )
+
+        current_request_cache = self.streaming_flow_cache[request_id]
+
+        current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
+        generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
+
+        chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
+            token=generated_speech_tokens,
+            spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
+            cache=current_request_cache,
+            last_chunk=last_chunk,
+            n_timesteps=10,
+        )
+
+        self.streaming_flow_cache[request_id] = new_streaming_flow_cache
+
+        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
+            self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
+            ], dim=4)
+
+        hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
+        hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
+        hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
+        mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()
+
+        speech, source = self.hift(mel, hift_cache_source)
+
+        # overlap speech smooth
+        if hift_cache_speech.shape[-1] > 0:
+            speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
+
+        # update vocoder cache
+        self.hift_cache_dict[request_id] = dict(
+            mel=mel[..., -self.mel_cache_len:].clone().detach(),
+            source=source[:, :, -self.source_cache_len:].clone().detach(),
+            speech=speech[:, -self.source_cache_len:].clone().detach(),
+        )
+        if not last_chunk:
+            speech = speech[:, :-self.source_cache_len]
+
+        if last_chunk:
+            assert request_id in self.streaming_flow_cache
+            self.streaming_flow_cache.pop(request_id)
+            self.hift_cache_dict.pop(request_id)
+
+        return speech
+
+
+def collate_fn(batch):
+    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
+    for item in batch:
+        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
+        audio = torch.from_numpy(item['prompt_audio']['array']).float()
+        prompt_audios_list.append(audio)
+        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
+        ids.append(item['id'])
+
+    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--enable-trt", action="store_true")
+    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--output-dir", type=str, default="generated_wavs")
+    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
+    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    dataset_name = "yuekai/seed_tts_cosy2"
+
+    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
+
+    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
+
+    for _ in range(args.warmup):
+        start_time = time.time()
+        for batch in data_loader:
+            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
+
+            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
+
+            for id, wav in zip(ids, generated_wavs):
+                torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
+        end_time = time.time()
+        epoch_time = end_time - start_time
+        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

+ 69 - 0
runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt

@@ -0,0 +1,69 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "token2wav_dit"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+    priority_levels: 10
+    default_priority_level: 10
+}
+
+parameters [
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "target_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+  },
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+  },
+  {
+    name: "reference_wav_len"
+    data_type: TYPE_INT32
+    dims: [1]
+  },
+  {
+    name: "finalize"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  }
+]
+output [
+  {
+    name: "waveform"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]

+ 99 - 10
runtime/triton_trtllm/offline_inference.py

@@ -28,7 +28,6 @@ import argparse
 import json
 import os
 import sys
-from pathlib import Path
 
 import torch
 import torch.distributed as dist
@@ -43,8 +42,9 @@ import soundfile as sf
 import s3tokenizer
 from functools import partial
 import time
-
-from token2wav import CosyVoice2_Token2Wav
+import requests
+import asyncio
+import httpx
 
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 try:
@@ -53,6 +53,32 @@ except RuntimeError:
     pass
 
 
+async def send_request_async(client, url, payload):
+    response = await client.post(url, json=payload, timeout=None)
+    response.raise_for_status()
+    response_json = response.json()
+    return response_json['choices'][0]['message']['content']
+
+
+async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
+    async with httpx.AsyncClient() as client:
+        tasks = []
+        for chat in chats:
+            payload = {
+                "model": model_name,
+                "messages": chat,
+                "max_tokens": 2048,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": 1.1,
+                "stop": ["<|eos1|>", "<|eos|>"],
+                "stream": False,
+            }
+            tasks.append(send_request_async(client, api_base, payload))
+        return await asyncio.gather(*tasks)
+
+
 def extract_speech_ids(speech_tokens_str):
     """Extract speech IDs from token strings like <|s_23456|>"""
     speech_ids = []
@@ -149,7 +175,7 @@ def get_args():
         "--backend",
         type=str,
         default="hf",
-        choices=["hf", "trtllm", "vllm"],
+        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
         help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
     )
     parser.add_argument(
@@ -164,6 +190,18 @@ def get_args():
         default=0.6,
         help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
     )
+    parser.add_argument(
+        "--openai-api-base",
+        type=str,
+        default="http://localhost:8000/v1/chat/completions",
+        help="OpenAI API base URL (for trtllm-serve backend)",
+    )
+    parser.add_argument(
+        "--openai-model-name",
+        type=str,
+        default="trt_engines_bfloat16",
+        help="Model name to use with OpenAI API (for trtllm-serve backend)",
+    )
     args = parser.parse_args()
     return args
 
@@ -180,6 +218,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     prompt_text_after_apply_template_list = []
     mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
+    chat_list = []
     for _, item in enumerate(batch):
         audio_processing_start_time = time.time()
         prompt_text, target_text = (
@@ -237,6 +276,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             {"role": "user", "content": full_text_list[i]},
             {"role": "assistant", "content": prompt_audio_cosy2_id_str}
         ]
+        chat_list.append(chat)
 
         assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
 
@@ -265,6 +305,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         "audio_processing_time": total_audio_processing_time,
         "speech_tokenization_time": total_speech_tokenization_time,
         "text_tokenization_time": total_text_tokenization_time,
+        "chat_list": chat_list
     }
 
 
@@ -318,9 +359,16 @@ def main(args):
     elif args.backend == "vllm":
         model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
         runner = None
+    elif args.backend == "trtllm-serve":
+        model = None
+        runner = None
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
-
+    if 'Step-Audio-2-mini' in args.token2wav_path:
+        from token2wav_dit import CosyVoice2_Token2Wav
+    else:
+        assert 'CosyVoice2-0.5B' in args.token2wav_path
+        from token2wav import CosyVoice2_Token2Wav
     token2wav_model = CosyVoice2_Token2Wav(
         model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
     )
@@ -452,6 +500,35 @@ def main(args):
                     print(outputs)
                     for j, output in enumerate(outputs):
                         outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
+                elif args.backend == "trtllm-serve":
+                    if args.batch_size > 1:
+                        outputs = asyncio.run(send_batch_requests_async(
+                            args.openai_api_base,
+                            args.openai_model_name,
+                            batch["chat_list"],
+                            args.temperature,
+                            args.top_p,
+                            args.top_k,
+                        ))
+                    else:
+                        outputs = []
+                        for chat in batch["chat_list"]:
+                            payload = {
+                                "model": args.openai_model_name,
+                                "messages": chat,
+                                "max_tokens": 2048,
+                                "temperature": args.temperature,
+                                "top_p": args.top_p,
+                                "top_k": args.top_k,
+                                "repetition_penalty": 1.1,
+                                "stop": ["<|eos1|>", "<|eos|>"],
+                                "stream": False,
+                            }
+                            response = requests.post(args.openai_api_base, json=payload)
+                            response.raise_for_status()
+                            response_json = response.json()
+                            generated_content = response_json['choices'][0]['message']['content']
+                            outputs.append(generated_content)
 
                 llm_end_time = time.time()
                 total_llm_time += (llm_end_time - llm_start_time)
@@ -459,10 +536,21 @@ def main(args):
                 items_for_token_2wav = []
                 for i in range(len(batch["ids"])):
                     llm_post_processing_start_time = time.time()
-                    input_length = len(batch["input_ids"][i])
-                    generated_ids = outputs[i][input_length:]
-                    speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-                    speech_ids = extract_speech_ids(speech_tokens_str)
+                    if args.backend == "trtllm-serve":
+                        speech_tokens_str = outputs[i].strip().split('><')
+                        if len(speech_tokens_str) > 1:
+                            speech_tokens_str = [
+                                t if t.startswith('<') else '<' + t for t in speech_tokens_str
+                            ]
+                            speech_tokens_str = [
+                                t if t.endswith('>') else t + '>' for t in speech_tokens_str
+                            ]
+                        speech_ids = extract_speech_ids(speech_tokens_str)
+                    else:
+                        input_length = len(batch["input_ids"][i])
+                        generated_ids = outputs[i][input_length:]
+                        speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+                        speech_ids = extract_speech_ids(speech_tokens_str)
                     print(i, speech_ids)
                     if len(speech_ids) == 0:
                         print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
@@ -502,7 +590,6 @@ def main(args):
                         t2w_prompt_audios_list,
                         t2w_prompt_audios_sample_rate,
                     )
-                    torch.cuda.synchronize()
                     token2wav_end_time = time.time()
                     total_token2wav_time += (token2wav_end_time - token2wav_start_time)
 
@@ -558,6 +645,8 @@ if __name__ == "__main__":
         from tensorrt_llm.runtime import ModelRunnerCpp
     elif args.backend == "hf":
         from transformers import AutoModelForCausalLM
+    elif args.backend == "trtllm-serve":
+        pass
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
     main(args)

+ 225 - 0
runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh

@@ -0,0 +1,225 @@
+#!/bin/bash
+# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
+export CUDA_VISIBLE_DEVICES=0
+cosyvoice_path=/workspace/CosyVoice
+stepaudio2_path=/workspace/Step-Audio2
+
+export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
+export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
+export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
+
+stage=$1
+stop_stage=$2
+
+huggingface_model_local_dir=./cosyvoice2_llm
+model_scope_model_local_dir=./CosyVoice2-0.5B
+step_audio_model_dir=./Step-Audio-2-mini
+
+trt_dtype=bfloat16
+trt_weights_dir=./trt_weights_${trt_dtype}
+trt_engines_dir=./trt_engines_${trt_dtype}
+
+model_repo=./model_repo_cosyvoice2_dit
+bls_instance_num=10
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+
+    echo "Cloning Step-Audio2-mini"
+    git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path
+
+    echo "Cloning CosyVoice"
+    git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
+    cd $cosyvoice_path
+    git submodule update --init --recursive
+    cd runtime/triton_trtllm
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+    echo "Downloading CosyVoice2-0.5B"
+    # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
+    huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
+    modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
+
+    echo "Step-Audio2-mini"
+    huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
+    cd $step_audio_model_dir/token2wav
+    wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
+    wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
+    cd -
+fi
+
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+    echo "Converting checkpoint to TensorRT weights"
+    python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
+                                --output_dir $trt_weights_dir \
+                                --dtype $trt_dtype || exit 1
+
+    echo "Building TensorRT engines"
+    trtllm-build --checkpoint_dir $trt_weights_dir \
+                --output_dir $trt_engines_dir \
+                --max_batch_size 64 \
+                --max_num_tokens 32768 \
+                --gemm_plugin $trt_dtype || exit 1
+
+    echo "Testing TensorRT engines"
+    python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
+                    --tokenizer_dir $huggingface_model_local_dir \
+                    --top_k 50 --top_p 0.95 --temperature 0.8 \
+                    --engine_dir=$trt_engines_dir  || exit 1
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+    echo "Creating model repository async mode"
+    rm -rf $model_repo
+    mkdir -p $model_repo
+    cosyvoice2_dir="cosyvoice2_dit"
+    token2wav_dir="token2wav_dit"
+
+    cp -r ./model_repo/${cosyvoice2_dir} $model_repo
+    cp -r ./model_repo/${token2wav_dir} $model_repo
+    cp -r ./model_repo/audio_tokenizer $model_repo
+    cp -r ./model_repo/speaker_embedding $model_repo
+
+
+    ENGINE_PATH=$trt_engines_dir
+    MAX_QUEUE_DELAY_MICROSECONDS=0
+    MODEL_DIR=$model_scope_model_local_dir
+    LLM_TOKENIZER_DIR=$huggingface_model_local_dir
+    BLS_INSTANCE_NUM=$bls_instance_num
+    TRITON_MAX_BATCH_SIZE=1
+    DECOUPLED_MODE=True # Only streaming TTS mode is supported using Nvidia Triton for now
+    STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav
+
+    python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+   echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
+   mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64  --kv_cache_free_gpu_memory_fraction 0.4 &
+   tritonserver --model-repository $model_repo --http-port 18000 &
+   wait
+    # Test using curl
+    # curl http://localhost:8000/v1/chat/completions \
+    #     -H "Content-Type: application/json" \
+    #     -d '{
+    #         "model": "",
+    #         "messages":[{"role": "user", "content": "Where is New York?"},
+    #                     {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
+    #         "max_tokens": 512,
+    #         "temperature": 0.8,
+    #         "top_p": 0.95,
+    #         "top_k": 50,
+    #         "stop": ["<|eos1|>"],
+    #         "repetition_penalty": 1.2,
+    #         "stream": false
+    #     }'
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+    echo "Running benchmark client"
+    num_task=4
+    mode=streaming
+    BLS_INSTANCE_NUM=$bls_instance_num
+
+    python3 client_grpc.py \
+        --server-addr localhost \
+        --server-port 8001 \
+        --model-name cosyvoice2_dit \
+        --num-tasks $num_task \
+        --mode $mode \
+        --huggingface-dataset yuekai/seed_tts_cosy2 \
+        --log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
+
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
+
+  datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
+  backend=trtllm # hf, trtllm, vllm, trtllm-serve
+
+  batch_sizes=(16)
+  token2wav_batch_size=1
+
+  for batch_size in ${batch_sizes[@]}; do
+    for dataset in ${datasets[@]}; do
+    output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
+    CUDA_VISIBLE_DEVICES=1 \
+        python3 offline_inference.py \
+            --output-dir $output_dir \
+            --llm-model-name-or-path $huggingface_model_local_dir \
+            --token2wav-path $step_audio_model_dir/token2wav \
+            --backend $backend \
+            --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+            --engine-dir $trt_engines_dir \
+            --split-name ${dataset} || exit 1
+    done
+  done
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+   echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script"
+   export CUDA_VISIBLE_DEVICES=1
+   # Note: Using pre-computed cosyvoice2 tokens
+   python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
+   # Offline Token2wav inference
+   python3 token2wav_dit.py --enable-trt
+fi
+
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+   echo "Disaggregated Server: LLM and Token2wav on different GPUs"
+   echo "Starting LLM server on GPU 0"
+   export CUDA_VISIBLE_DEVICES=0
+   mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64  --kv_cache_free_gpu_memory_fraction 0.4 &
+   echo "Starting Token2wav server on GPUs 1-3"
+   Token2wav_num_gpus=3
+   http_port=17000
+   grpc_port=18000
+   metrics_port=16000
+   for i in $(seq 0 $(($Token2wav_num_gpus - 1))); do
+       echo "Starting server on GPU $i"
+       http_port=$((http_port + 1))
+       grpc_port=$((grpc_port + 1))
+       metrics_port=$((metrics_port + 1))
+       # Two instances of Token2wav server on the same GPU
+       CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
+       http_port=$((http_port + 1))
+       grpc_port=$((grpc_port + 1))
+       metrics_port=$((metrics_port + 1))
+       CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
+   done
+   wait
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+    echo "Running benchmark client for Disaggregated Server"
+    per_gpu_instances=2
+    mode=streaming
+    BLS_INSTANCE_NUM=$bls_instance_num
+    Token2wav_num_gpus=(1 2 3)
+    concurrent_tasks=(1 2 3 4 5 6)
+    for n_gpu in ${Token2wav_num_gpus[@]}; do
+        echo "Test 1 GPU for LLM server and $n_gpu GPUs for Token2wav servers"
+        for concurrent_task in ${concurrent_tasks[@]}; do
+            num_instances=$((per_gpu_instances * n_gpu))
+            for i in $(seq 1 $num_instances); do
+                port=$(($i + 18000))
+                python3 client_grpc.py \
+                    --server-addr localhost \
+                    --server-port $port \
+                    --model-name cosyvoice2_dit \
+                    --num-tasks $concurrent_task \
+                    --mode $mode \
+                    --huggingface-dataset yuekai/seed_tts_cosy2 \
+                    --log-dir ./log_disagg_concurrent_tasks_${concurrent_task}_per_instance_total_token2wav_instances_${num_instances}_port_${port} &
+            done
+            wait
+        done
+    done
+fi

+ 0 - 5
runtime/triton_trtllm/scripts/test_llm.py

@@ -15,11 +15,6 @@
 # limitations under the License.
 
 import argparse
-import ast
-import csv
-import os
-from pathlib import Path
-from typing import List, Optional
 
 import numpy as np
 import torch

+ 122 - 0
runtime/triton_trtllm/streaming_inference.py

@@ -0,0 +1,122 @@
+import torch
+import os
+import argparse
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+import numpy as np
+import torchaudio
+import time
+from token2wav_dit import CosyVoice2_Token2Wav
+import soundfile as sf
+
+
+def collate_fn(batch):
+    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
+    prompt_speech_tokens_list, prompt_text_list = [], []
+    for item in batch:
+        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
+        audio = torch.from_numpy(item['prompt_audio']['array']).float()
+        prompt_audios_list.append(audio)
+        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
+        ids.append(item['id'])
+        prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
+        prompt_text_list.append(item['prompt_text'])
+
+    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--enable-trt", action="store_true")
+    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--output-dir", type=str, default="generated_wavs")
+    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
+    parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
+    parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    dataset_name = args.dataset_name
+    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
+    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
+
+    token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
+
+    CHUNK_SIZE = 25
+    token_frame_rate = 25
+    OVERLAP_SIZE = 0
+
+    warmup_times = 3
+    for _ in range(warmup_times):
+        start_time = time.time()
+        total_forward_count = 0
+        for batch in data_loader:
+            tts_speech_list = []
+            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
+
+            id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
+
+            assert prompt_audio_sample_rate == 16000
+
+            prompt_text = prompt_text_list[0]
+            prompt_speech_tokens = prompt_speech_tokens_list[0]
+
+            semantic_token_ids_arr, token_offset = [], 0
+            flow_prompt_speech_token_len = len(prompt_speech_tokens)
+
+            buffer = generated_speech_tokens
+            output_wavs = []
+            chunk_index = 0
+            while True:
+                if args.strategy == "equal":
+                    this_chunk_size = CHUNK_SIZE
+                elif args.strategy == "exponential":
+                    this_chunk_size = token_frame_rate * (2 ** chunk_index)
+
+                if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
+                    wavs = token2wav_model.forward_streaming(
+                        buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
+                        False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
+                        prompt_audio_sample_rate=prompt_audio_sample_rate
+                    )
+                    buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
+
+                    output_wavs.append(wavs)
+                    total_forward_count += 1
+                    chunk_index += 1
+
+                else:
+                    wavs = token2wav_model.forward_streaming(
+                        buffer, True, request_id=id, speaker_id=f"{id}",
+                        prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
+                    )
+                    output_wavs.append(wavs)
+                    total_forward_count += 1
+                    # chunk_index += 1
+                    break
+
+            for i, wav in enumerate(output_wavs):
+                output_wavs[i] = wav.cpu().numpy().squeeze()
+
+            audios = output_wavs
+            reconstructed_audio = np.concatenate(audios)
+            sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
+
+        end_time = time.time()
+
+        if _ == 0:
+            token2wav_model.speaker_cache = {}
+            print(f"Warmup time: {end_time - start_time} seconds")
+            print("clear speaker cache")
+        elif _ == 1:
+            print(f"Cost time without speaker cache: {end_time - start_time} seconds")
+        else:
+            print(f"Cost time with speaker cache: {end_time - start_time} seconds")
+            print(f"Total flow matching forward calls: {total_forward_count}")

+ 1 - 0
runtime/triton_trtllm/token2wav_dit.py

@@ -0,0 +1 @@
+model_repo/token2wav_dit/1/token2wav_dit.py