Merge pull request #1598 from yuekaizhang/streaming

[Runtime] StepAudio2 Streaming DiT Token2Wav Integration
Xiang Lyu 6 months ago
parent commit 6e01309e01

+ 2 - 2
examples/grpo/cosyvoice2/README.md

@@ -36,7 +36,7 @@ Stage `0` converts raw JSONL files into the parquet format expected by veRL:
 ```bash
 bash run.sh 0 0
 ```
-Create two JSONL files—`train.jsonl` and `test.jsonl`.  
+Create two JSONL files—`train.jsonl` and `test.jsonl`.
 The script will then generate two Parquet files:
 
 ```
@@ -111,7 +111,7 @@ bash run.sh 5 5
 
 The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
 > [!TIP]
->  However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format. 
+>  However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
 
 ## Results
 

+ 1 - 1
examples/grpo/cosyvoice2/infer_dataset.py

@@ -53,7 +53,7 @@ except RuntimeError:
     pass
 
 
-TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"  # noqa: E501
 
 
 def audio_decode_cosyvoice2(

+ 0 - 2
examples/grpo/cosyvoice2/pretrained_to_huggingface.py

@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-
 # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #

+ 3 - 3
examples/grpo/cosyvoice2/run.sh

@@ -33,7 +33,7 @@ fi
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
-  modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path 
+  modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
   python3 pretrained_to_huggingface.py \
     --pretrained-cosyvoice2-path $model_scope_model_path \
     --save-path $sft_model_path
@@ -61,7 +61,7 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "stage 1: start token2wav asr server for reward function"
   python3 token2wav_asr_server.py --number-of-devices 8
-fi 
+fi
 
 exp_name=official_llm_aishell3_grpo
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -125,7 +125,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
       --backend fsdp \
       --local_dir $llm_path/actor \
       --target_dir $llm_path/merged_hf_model || exit 1
-fi 
+fi
 
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "stage 4: Test the model"

+ 1 - 3
examples/grpo/cosyvoice2/scripts/offline-decode-files.py

@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-#
 # Copyright (c)  2023 by manyeyes
 # Copyright (c)  2023  Xiaomi Corporation
 
@@ -195,7 +193,7 @@ def write_error_stats(
             hyp = list("".join(hyp))
             results[i] = (cut_id, ref, hyp)
 
-    for cut_id, ref, hyp in results:
+    for _cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
         for ref_word, hyp_word in ali:
             if ref_word == ERR:

+ 1 - 1
examples/grpo/cosyvoice2/token2wav_asr_server.py

@@ -295,7 +295,7 @@ def main():
         metrics_port=8002,
     )
 
-    device_ids = [i for i in range(args.number_of_devices)]
+    device_ids = list(range(args.number_of_devices))
     device_ids = device_ids * args.number_of_instances_per_device
 
     with Triton(config=triton_config) as triton:

+ 141 - 0
runtime/triton_trtllm/README.DIT.md

@@ -0,0 +1,141 @@
+## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM
+
+Contributed by Yuekai Zhang (NVIDIA).
+
+This document describes how to accelerate CosyVoice with a DiT-based Token2Wav module from Step-Audio2, using NVIDIA Triton Inference Server and TensorRT-LLM.
+
+### Quick Start
+
+Launch the service directly with Docker Compose:
+```sh
+docker compose -f docker-compose.dit.yml up
+```
+
+### Build the Docker Image
+
+To build the image from scratch:
+```sh
+docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
+```
+
+### Run a Docker Container
+```sh
+your_mount_dir=/mnt:/mnt
+docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
+```
+
+### Understanding `run_stepaudio2_dit_token2wav.sh`
+
+The `run_stepaudio2_dit_token2wav.sh` script orchestrates the entire workflow through numbered stages.
+
+You can run a subset of stages with:
+```sh
+bash run_stepaudio2_dit_token2wav.sh <start_stage> <stop_stage>
+```
+- `<start_stage>`: The stage to start from.
+- `<stop_stage>`: The stage to stop after.
+
+**Stages:**
+
+- **Stage -1**: Clones the `Step-Audio2` and `CosyVoice` repositories.
+- **Stage 0**: Downloads the `cosyvoice2_llm`, `CosyVoice2-0.5B`, and `Step-Audio-2-mini` models.
+- **Stage 1**: Converts the HuggingFace checkpoint for the LLM to the TensorRT-LLM format and builds the TensorRT engines.
+- **Stage 2**: Creates the Triton model repository, including configurations for `cosyvoice2_dit` and `token2wav_dit`.
+- **Stage 3**: Launches the Triton Inference Server for the Token2Wav module and uses `trtllm-serve` to deploy the CosyVoice2 LLM.
+- **Stage 4**: Runs the gRPC benchmark client for performance testing.
+- **Stage 5**: Runs the offline TTS inference benchmark test.
+- **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model.
+- **Stage 7**: Launches servers in a disaggregated setup, with the LLM on GPU 0 and Token2Wav servers on GPUs 1-3.
+- **Stage 8**: Runs the benchmark client for the disaggregated server configuration.
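+
+For example, assuming the stage numbering above, the following would clone the repositories, download the models, and build the TensorRT-LLM engines for the LLM:
+```sh
+# Stages -1 through 1: clone repos, download models, build the TRT-LLM engines
+bash run_stepaudio2_dit_token2wav.sh -1 1
+```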
+
+### Export Models and Launch Server
+
+Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
+```sh
+# This command runs stages 0, 1, 2, and 3
+bash run_stepaudio2_dit_token2wav.sh 0 3
+```
+
+### Benchmark with Client-Server Mode
+
+To benchmark the running Triton server, run stage 4:
+```sh
+bash run_stepaudio2_dit_token2wav.sh 4 4
+
+# You can customize parameters such as the number of tasks inside the script.
+```
+The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
+
+#### Total Request Latency
+
+| Concurrent Tasks | RTF    | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
+| ---------------- | ------ | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
+| 1                | 0.1228 | 833.66       | 779.98               | 1297.05              | 1555.97              | 1653.02              |
+| 2                | 0.0901 | 1166.23      | 1124.69              | 1762.76              | 1900.64              | 2204.14              |
+| 4                | 0.0741 | 1849.30      | 1759.42              | 2624.50              | 2822.20              | 3128.42              |
+| 6                | 0.0774 | 2936.13      | 3054.64              | 3849.60              | 3900.49              | 4245.79              |
+| 8                | 0.0691 | 3408.56      | 3434.98              | 4547.13              | 5047.76              | 5346.53              |
+| 10               | 0.0707 | 4306.56      | 4343.44              | 5769.64              | 5876.09              | 5939.79              |
+
+#### First Chunk Latency
+
+| Concurrent Tasks | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
+| ---------------- | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
+| 1                | 197.50       | 196.13               | 214.65               | 215.96               | 229.21               |
+| 2                |  281.15       | 278.20               | 345.18               | 361.79               | 395.97               |
+| 4                |  510.65       | 530.50               | 630.13               | 642.44               | 666.65               |
+| 6                |  921.54       | 918.86               | 1079.97              | 1265.22              | 1524.41              |
+| 8                |  1019.95      | 1085.26              | 1371.05              | 1402.24              | 1410.66              |
+| 10               |  1214.98      | 1293.54              | 1575.36              | 1654.51              | 2161.76              |
+
+### Benchmark with Offline Inference Mode
+
+To benchmark in offline inference mode, run stage 5:
+```sh
+bash run_stepaudio2_dit_token2wav.sh 5 5
+```
+
+The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
+
+#### Offline TTS (CosyVoice2 0.5B LLM + StepAudio2 DiT Token2Wav)
+
+| Backend | Batch Size | LLM Time (s) | Total Time (s) | RTF    |
+|---------|------------|--------------|----------------|--------|
+| TRTLLM  | 16         | 2.01         | 5.03           | 0.0292 |
+
+
+### Disaggregated Server
+When the LLM and token2wav components are deployed on the same GPU, they compete for resources. To optimize performance, we use a disaggregated setup where the LLM is deployed on one dedicated L20 GPU, taking advantage of in-flight batching for inference. The token2wav module is deployed on separate, dedicated GPUs.
+
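+Assuming the stage numbering described above, this disaggregated configuration corresponds to stages 7 and 8 of the same script:
+```sh
+# Stage 7: LLM server on GPU 0, Token2Wav servers on GPUs 1-3
+bash run_stepaudio2_dit_token2wav.sh 7 7
+# Stage 8: benchmark client for the disaggregated setup
+bash run_stepaudio2_dit_token2wav.sh 8 8
+```
+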
+The table below shows the first chunk latency results for this configuration. In our tests, we deploy two token2wav instances on each dedicated token2wav GPU.
+
+| token2wav_num_gpu | concurrent_task_per_instance | concurrent_tasks_per_gpu | avg (ms) | p50 (ms) | p90 (ms) | p99 (ms) |
+|---|---|---|---|---|---|---|
+| 1 | 1 | 1.00 | 218.53 | 217.86 | 254.07 | 296.49 |
+| 2 | 1 | 1.33 | 218.82 | 219.21 | 256.62 | 303.13 |
+| 3 | 1 | 1.50 | 229.08 | 223.27 | 302.13 | 324.41 |
+| 4 | 1 | 1.60 | 203.87 | 198.23 | 254.92 | 279.31 |
+| 1 | 2 | 2.00 | 293.46 | 280.53 | 370.81 | 407.40 |
+| 2 | 2 | 2.67 | 263.38 | 236.84 | 350.82 | 397.39 |
+| 3 | 2 | 3.00 | 308.09 | 275.48 | 385.22 | 521.45 |
+| 4 | 2 | 3.20 | 271.85 | 253.25 | 359.03 | 387.91 |
+| 1 | 3 | 3.00 | 389.15 | 373.01 | 469.22 | 542.89 |
+| 2 | 3 | 4.00 | 403.48 | 394.80 | 481.24 | 507.75 |
+| 3 | 3 | 4.50 | 406.33 | 391.28 | 495.43 | 571.29 |
+| 4 | 3 | 4.80 | 436.72 | 383.81 | 638.44 | 879.23 |
+| 1 | 4 | 4.00 | 520.12 | 493.98 | 610.38 | 739.85 |
+| 2 | 4 | 5.33 | 494.60 | 490.50 | 605.93 | 708.09 |
+| 3 | 4 | 6.00 | 538.23 | 508.33 | 687.62 | 736.96 |
+| 4 | 4 | 6.40 | 579.68 | 546.20 | 721.53 | 958.04 |
+| 1 | 5 | 5.00 | 635.02 | 623.30 | 786.85 | 819.84 |
+| 2 | 5 | 6.67 | 598.23 | 617.09 | 741.00 | 788.96 |
+| 3 | 5 | 7.50 | 644.78 | 684.40 | 786.45 | 1009.45 |
+| 4 | 5 | 8.00 | 733.92 | 642.26 | 1024.79 | 1281.55 |
+| 1 | 6 | 6.00 | 715.38 | 745.68 | 887.04 | 906.68 |
+| 2 | 6 | 8.00 | 748.31 | 753.94 | 873.59 | 1007.14 |
+| 3 | 6 | 9.00 | 900.27 | 822.28 | 1431.14 | 1800.23 |
+| 4 | 6 | 9.60 | 857.54 | 820.33 | 1150.30 | 1298.53 |
+
+The `concurrent_tasks_per_gpu` value is calculated as:
+`concurrent_tasks_per_gpu = concurrent_task_per_instance * num_token2wav_instances_per_gpu (2) * token2wav_gpus / (token2wav_gpus + llm_gpus (1))`
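+
+For example, with 2 token2wav GPUs and 2 concurrent tasks per instance this gives `2 * 2 * 2 / (2 + 1) = 2.67` concurrent tasks per GPU, matching the corresponding row in the table above.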
+
+### Acknowledgements
+
+This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).

+ 187 - 121
runtime/triton_trtllm/client_grpc.py

@@ -43,9 +43,9 @@ python3 client_grpc.py \
 import argparse
 import asyncio
 import json
-import queue  # Added
-import uuid  # Added
-import functools  # Added
+import queue
+import uuid
+import functools
 
 import os
 import time
@@ -55,16 +55,16 @@ from pathlib import Path
 import numpy as np
 import soundfile as sf
 import tritonclient
-import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
-import tritonclient.grpc as grpcclient_sync  # Added sync client import
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException
+import tritonclient.grpc.aio as grpcclient_aio
+import tritonclient.grpc as grpcclient_sync
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException
 
 
-# --- Added UserData and callback ---
 class UserData:
     def __init__(self):
         self._completed_requests = queue.Queue()
         self._first_chunk_time = None
+        self._second_chunk_time = None
         self._start_time = None
 
     def record_start_time(self):
@@ -75,39 +75,43 @@ class UserData:
             return self._first_chunk_time - self._start_time
         return None
 
+    def get_second_chunk_latency(self):
+        if self._first_chunk_time and self._second_chunk_time:
+            return self._second_chunk_time - self._first_chunk_time
+        return None
+
 
 def callback(user_data, result, error):
-    if user_data._first_chunk_time is None and not error:
-        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
+    if not error:
+        if user_data._first_chunk_time is None:
+            user_data._first_chunk_time = time.time()
+        elif user_data._second_chunk_time is None:
+            user_data._second_chunk_time = time.time()
+
     if error:
         user_data._completed_requests.put(error)
     else:
         user_data._completed_requests.put(result)
-# --- End Added UserData and callback ---
+
+
+def stream_callback(user_data_map, result, error):
+    request_id = None
+    if error:
+        print(f"An error occurred in the stream callback: {error}")
+    else:
+        request_id = result.get_response().id
+
+    if request_id:
+        user_data = user_data_map.get(request_id)
+        if user_data:
+            callback(user_data, result, error)
+        else:
+            print(f"Warning: Could not find user_data for request_id {request_id}")
 
 
 def write_triton_stats(stats, summary_file):
     with open(summary_file, "w") as summary_f:
         model_stats = stats["model_stats"]
-        # write a note, the log is from triton_client.get_inference_statistics(), to better human readability
-        summary_f.write(
-            "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
-        )
-        summary_f.write("To learn more about the log, please refer to: \n")
-        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
-        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
-        summary_f.write(
-            "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
-        )
-        summary_f.write(
-            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
-        )
-        summary_f.write(
-            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
-        )
-        summary_f.write(
-            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
-        )
         for model_state in model_stats:
             if "last_inference" not in model_state:
                 continue
@@ -118,7 +122,10 @@ def write_triton_stats(stats, summary_file):
             total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
             total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
             summary_f.write(
-                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
+                f"queue time {total_queue_time_s:<5.2f} s, "
+                f"compute infer time {total_infer_time_s:<5.2f} s, "
+                f"compute input time {total_input_time_s:<5.2f} s, "
+                f"compute output time {total_output_time_s:<5.2f} s \n"
             )
             model_batch_stats = model_state["batch_stats"]
             for batch in model_batch_stats:
@@ -127,21 +134,86 @@ def write_triton_stats(stats, summary_file):
                 compute_output = batch["compute_output"]
                 compute_infer = batch["compute_infer"]
                 batch_count = int(compute_infer["count"])
+                if batch_count == 0:
+                    continue
                 assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
                 compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
                 compute_input_time_ms = int(compute_input["ns"]) / 1e6
                 compute_output_time_ms = int(compute_output["ns"]) / 1e6
                 summary_f.write(
-                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
+                    f"executed inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
+                    f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
+                    f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
+                    f"{compute_infer_time_ms / batch_count:.2f} ms, "
+                    f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
+                    f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
                 )
                 summary_f.write(
-                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
+                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
                 )
                 summary_f.write(
-                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
+                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
                 )
 
 
+def subtract_stats(stats_after, stats_before):
+    """Subtracts two Triton inference statistics objects."""
+    stats_diff = json.loads(json.dumps(stats_after))
+
+    model_stats_before_map = {
+        s["name"]: {
+            "version": s["version"],
+            "last_inference": s.get("last_inference", 0),
+            "inference_count": s.get("inference_count", 0),
+            "execution_count": s.get("execution_count", 0),
+            "inference_stats": s.get("inference_stats", {}),
+            "batch_stats": s.get("batch_stats", []),
+        }
+        for s in stats_before["model_stats"]
+    }
+
+    for model_stat_after in stats_diff["model_stats"]:
+        model_name = model_stat_after["name"]
+        if model_name in model_stats_before_map:
+            model_stat_before = model_stats_before_map[model_name]
+
+            model_stat_after["inference_count"] = str(
+                int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
+            )
+            model_stat_after["execution_count"] = str(
+                int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
+            )
+
+            if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
+                for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
+                    if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
+                        if "ns" in model_stat_after["inference_stats"][key]:
+                            ns_after = int(model_stat_after["inference_stats"][key]["ns"])
+                            ns_before = int(model_stat_before["inference_stats"][key]["ns"])
+                            model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
+                        if "count" in model_stat_after["inference_stats"][key]:
+                            count_after = int(model_stat_after["inference_stats"][key]["count"])
+                            count_before = int(model_stat_before["inference_stats"][key]["count"])
+                            model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
+
+            if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
+                batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
+                for batch_stat_after in model_stat_after["batch_stats"]:
+                    bs = batch_stat_after["batch_size"]
+                    if bs in batch_stats_before_map:
+                        batch_stat_before = batch_stats_before_map[bs]
+                        for key in ["compute_input", "compute_infer", "compute_output"]:
+                            if key in batch_stat_after and key in batch_stat_before:
+                                count_after = int(batch_stat_after[key]["count"])
+                                count_before = int(batch_stat_before[key]["count"])
+                                batch_stat_after[key]["count"] = str(count_after - count_before)
+
+                                ns_after = int(batch_stat_after[key]["ns"])
+                                ns_before = int(batch_stat_before[key]["ns"])
+                                batch_stat_after[key]["ns"] = str(ns_after - ns_before)
+    return stats_diff
+
+
 def get_args():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
@@ -209,7 +281,8 @@ def get_args():
         choices=[
             "f5_tts",
             "spark_tts",
-            "cosyvoice2"],
+            "cosyvoice2",
+            "cosyvoice2_dit"],
         help="triton model_repo module name to request",
     )
 
@@ -243,7 +316,6 @@ def get_args():
         help="log directory",
     )
 
-    # --- Added arguments ---
     parser.add_argument(
         "--mode",
         type=str,
@@ -260,8 +332,8 @@
 
     parser.add_argument(
         "--use-spk2info-cache",
-        type=bool,
-        default=False,
+        type=str,
+        default="False",
         help="Use spk2info cache for reference audio.",
     )
 
@@ -284,39 +356,33 @@ def load_audio(wav_path, target_sample_rate=16000):
 
 
 def prepare_request_input_output(
-    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
+    protocol_client,
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None,  # Optional padding for offline mode
+    padding_duration: int = None,
     use_spk2info_cache: bool = False
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
     lengths = np.array([[len(waveform)]], dtype=np.int32)
 
-    # Apply padding only if padding_duration is provided (for offline)
     if padding_duration:
         duration = len(waveform) / sample_rate
-        # Estimate target duration based on text length ratio (crude estimation)
-        # Avoid division by zero if reference_text is empty
         if reference_text:
             estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
-            estimated_target_duration = duration  # Assume target duration similar to reference if no text
+            estimated_target_duration = duration
 
-        # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
             (int(estimated_target_duration + duration) // padding_duration) + 1
         )
         samples = np.zeros((1, required_total_samples), dtype=np.float32)
         samples[0, : len(waveform)] = waveform
     else:
-        # No padding for streaming or if padding_duration is None
         samples = waveform.reshape(1, -1).astype(np.float32)
 
-    # Common input creation logic
     inputs = [
         protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
         protocol_client.InferInput(
@@ -355,12 +421,8 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
-    user_data.record_start_time()  # Record start time for first chunk latency calculation
-
-    # Establish stream
-    sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
+    user_data.record_start_time()
 
-    # Send request
     sync_triton_client.async_stream_infer(
         model_name,
         inputs,
@@ -369,91 +431,76 @@ def run_sync_streaming_inference(
         enable_empty_final_response=True,
     )
 
-    # Process results
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get()  # Add timeout
+            result = user_data._completed_requests.get(timeout=200)
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
-                sync_triton_client.stop_stream()
-                return None, None, None  # Indicate error
-            # Get response metadata
+                return None, None, None, None
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
             if final is True:
                 break
 
             audio_chunk = result.as_numpy("waveform").reshape(-1)
-            if audio_chunk.size > 0:  # Only append non-empty chunks
+            if audio_chunk.size > 0:
                 audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")
 
         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
-            sync_triton_client.stop_stream()
-            return None, None, None  # Indicate error
+            return None, None, None, None
 
-    sync_triton_client.stop_stream()
     end_time_total = time.time()
     total_request_latency = end_time_total - start_time_total
     first_chunk_latency = user_data.get_first_chunk_latency()
+    second_chunk_latency = user_data.get_second_chunk_latency()
 
-    # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
-    actual_duration = 0
     if audios:
-        # Only spark_tts model uses cross-fade
         if model_name == "spark_tts":
             cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
             fade_out = np.linspace(1, 0, cross_fade_samples)
             fade_in = np.linspace(0, 1, cross_fade_samples)
             reconstructed_audio = None
 
-            # Simplified reconstruction based on client_grpc_streaming.py
             if not audios:
                 print("Warning: No audio chunks received.")
-                reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
+                reconstructed_audio = np.array([], dtype=np.float32)
             elif len(audios) == 1:
                 reconstructed_audio = audios[0]
             else:
-                reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
+                reconstructed_audio = audios[0][:-cross_fade_samples]
                 for i in range(1, len(audios)):
-                    # Cross-fade section
                     cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
                                            audios[i - 1][-cross_fade_samples:] * fade_out)
-                    # Middle section of the current chunk
                     middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-                    # Concatenate
                     reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
-                # Add the last part of the final chunk
                 reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
 
             if reconstructed_audio is not None and reconstructed_audio.size > 0:
                 actual_duration = len(reconstructed_audio) / save_sample_rate
-                # Save reconstructed audio
                 sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
             else:
                 print("Warning: No audio chunks received or reconstructed.")
-                actual_duration = 0  # Set duration to 0 if no audio
+                actual_duration = 0
         else:
             reconstructed_audio = np.concatenate(audios)
-            print(f"reconstructed_audio: {reconstructed_audio.shape}")
             actual_duration = len(reconstructed_audio) / save_sample_rate
-            # Save reconstructed audio
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
 
     else:
         print("Warning: No audio chunks received.")
         actual_duration = 0
 
-    return total_request_latency, first_chunk_latency, actual_duration
+    return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
 
 
 async def send_streaming(
     manifest_item_list: list,
     name: str,
-    server_url: str,  # Changed from sync_triton_client
+    server_url: str,
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -466,11 +513,13 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
-    sync_triton_client = None  # Initialize client variable
+    sync_triton_client = None
+    user_data_map = {}
 
-    try:  # Wrap in try...finally to ensure client closing
+    try:
         print(f"{name}: Initializing sync client for streaming...")
-        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
+        sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
 
         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -490,12 +539,13 @@ async def send_streaming(
                     padding_duration=padding_duration,
                     use_spk2info_cache=use_spk2info_cache
                 )
+
                 request_id = str(uuid.uuid4())
                 user_data = UserData()
+                user_data_map[request_id] = user_data
 
                 audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
-
-                total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
+                total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
                     run_sync_streaming_inference,
                     sync_triton_client,
                     model_name,
@@ -509,12 +559,18 @@ async def send_streaming(
                 )
 
                 if total_request_latency is not None:
-                    print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
-                    latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
+                    print(
+                        f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
+                        f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
+                        f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
+                    )
+                    latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
                     total_duration += actual_duration
                 else:
                     print(f"{name}: Item {i} failed.")
 
+                del user_data_map[request_id]
+
             except FileNotFoundError:
                 print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
             except Exception as e:
@@ -522,10 +578,11 @@ async def send_streaming(
                 import traceback
                 traceback.print_exc()
 
-    finally:  # Ensure client is closed
+    finally:
         if sync_triton_client:
             try:
-                print(f"{name}: Closing sync client...")
+                print(f"{name}: Closing stream and sync client...")
+                sync_triton_client.stop_stream()
                 sync_triton_client.close()
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")
@@ -550,7 +607,6 @@ async def send(
     latency_data = []
     task_id = int(name[5:])
 
-    print(f"manifest_item_list: {manifest_item_list}")
     for i, item in enumerate(manifest_item_list):
         if i % log_interval == 0:
             print(f"{name}: {i}/{len(manifest_item_list)}")
@@ -591,7 +647,6 @@ def load_manifests(manifest_path):
             assert len(line.strip().split("|")) == 4
             utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
             utt = Path(utt).stem
-            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
             if not os.path.isabs(prompt_wav):
                 prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
             manifest_list.append(
@@ -632,23 +687,17 @@ async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
 
-    # --- Client Initialization based on mode ---
     triton_client = None
     protocol_client = None
     if args.mode == "offline":
         print("Initializing gRPC client for offline mode...")
-        # Use the async client for offline tasks
         triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
         protocol_client = grpcclient_aio
     elif args.mode == "streaming":
         print("Initializing gRPC client for streaming mode...")
-        # Use the sync client for streaming tasks, handled via asyncio.to_thread
-        # We will create one sync client instance PER TASK inside send_streaming.
-        # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
-        protocol_client = grpcclient_sync  # protocol client for input prep
+        protocol_client = grpcclient_sync
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
-    # --- End Client Initialization ---
 
     if args.reference_audio:
         args.num_tasks = 1
@@ -682,15 +731,24 @@ async def main():
     else:
         manifest_item_list = load_manifests(args.manifest_path)
 
+    stats_client = None
+    stats_before = None
+    try:
+        print("Initializing temporary async client for fetching stats...")
+        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
+        print("Fetching inference statistics before running tasks...")
+        stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
+    except Exception as e:
+        print(f"Could not retrieve statistics before running tasks: {e}")
+
     num_tasks = min(args.num_tasks, len(manifest_item_list))
     manifest_item_list = split_data(manifest_item_list, num_tasks)
 
     os.makedirs(args.log_dir, exist_ok=True)
-
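+    # Note: argparse's `type=bool` treats any non-empty string (including "False") as True, so the flag is parsed as a string and converted to a real boolean here.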
+    args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
     tasks = []
     start_time = time.time()
     for i in range(num_tasks):
-        # --- Task Creation based on mode ---
         if args.mode == "offline":
             task = asyncio.create_task(
                 send(
@@ -711,7 +769,7 @@ async def main():
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
-                    server_url=url,  # Pass URL instead of client
+                    server_url=url,
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -722,7 +780,6 @@ async def main():
                     use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
-        # --- End Task Creation ---
         tasks.append(task)
 
     ans_list = await asyncio.gather(*tasks)
@@ -735,7 +792,7 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
-            latency_data.extend(ans[1])  # Use extend for list of lists
+            latency_data.extend(ans[1])
         else:
             print("Warning: A task returned None, possibly due to an error.")
 
@@ -751,10 +808,8 @@ async def main():
     s += f"({total_duration / 3600:.2f} hours)\n"
     s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
 
-    # --- Statistics Reporting based on mode ---
     if latency_data:
         if args.mode == "offline":
-            # Original offline latency calculation
             latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
             if latency_list:
                 latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
@@ -769,9 +824,9 @@ async def main():
                 s += "No latency data collected for offline mode.\n"
 
         elif args.mode == "streaming":
-            # Calculate stats for total request latency and first chunk latency
-            total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
-            first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
+            total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
+            first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
+            second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
 
             s += "\n--- Total Request Latency ---\n"
             if total_latency_list:
@@ -798,9 +853,21 @@ async def main():
                 s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
             else:
                 s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
+
+            s += "\n--- Second Chunk Latency ---\n"
+            if second_chunk_latency_list:
+                avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
+                variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
+                s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
+                s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
+                s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
+                s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
+                s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
+                s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
+            else:
+                s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
     else:
         s += "No latency data collected.\n"
-    # --- End Statistics Reporting ---
 
     print(s)
     if args.manifest_path:
@@ -810,26 +877,27 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
-        name = "results"  # Default name if no manifest/split/audio provided
+        name = "results"
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)
 
-    # --- Statistics Fetching using temporary Async Client ---
-    # Use a separate async client for fetching stats regardless of mode
-    stats_client = None
     try:
-        print("Initializing temporary async client for fetching stats...")
-        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
-        print("Fetching inference statistics...")
-        # Fetching for all models, filtering might be needed depending on server setup
-        stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
-        print("Fetching model config...")
-        metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
+        if stats_client and stats_before:
+            print("Fetching inference statistics after running tasks...")
+            stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
+
+            print("Calculating statistics difference...")
+            stats = subtract_stats(stats_after, stats_before)
 
-        write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
+            print("Fetching model config...")
+            metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
 
-        with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
-            json.dump(metadata, f, indent=4)
+            write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
+
+            with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
+                json.dump(metadata, f, indent=4)
+        else:
+            print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
 
     except Exception as e:
         print(f"Could not retrieve statistics or config: {e}")
@@ -840,11 +908,9 @@ async def main():
                 await stats_client.close()
             except Exception as e:
                 print(f"Error closing async stats client: {e}")
-    # --- End Statistics Fetching ---
 
 
 if __name__ == "__main__":
-    # asyncio.run(main()) # Use TaskGroup for better exception handling if needed
     async def run_main():
         try:
             await main()

+ 0 - 1
runtime/triton_trtllm/client_http.py

@@ -25,7 +25,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import requests
 import soundfile as sf
-import json
 import numpy as np
 import numpy as np
 import argparse
 import argparse
 
 

+ 20 - 0
runtime/triton_trtllm/docker-compose.dit.yml

@@ -0,0 +1,20 @@
+services:
+  tts:
+    image: soar97/triton-cosyvoice:25.06
+    shm_size: '1gb'
+    ports:
+      - "8000:8000"
+      - "8001:8001"
+      - "8002:8002"
+    environment:
+      - PYTHONIOENCODING=utf-8
+      - MODEL_ID=${MODEL_ID}
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ['0']
+              capabilities: [gpu]
+    command: >
+      /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3"

+ 0 - 3
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

@@ -25,12 +25,9 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
 import json
 import json
-import math
 import os
 import os
-import re
 import threading
 import threading
 import time
 import time
-from typing import Dict, List, Tuple, Optional, Union
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch

+ 394 - 0
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py

@@ -0,0 +1,394 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import math
+import os
+import re
+import time
+from typing import Dict, List, Tuple, Optional, Union
+import asyncio
+import httpx
+
+import numpy as np
+import torch
+from torch.utils.dlpack import from_dlpack, to_dlpack
+import triton_python_backend_utils as pb_utils
+from transformers import AutoTokenizer
+
+import torchaudio
+
+
+from matcha.utils.audio import mel_spectrogram
+
+
+ORIGINAL_VOCAB_SIZE = 151663
+torch.set_num_threads(1)
+
+
+def parse_speech_token_string(response_text: str) -> List[int]:
+    """
+    Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
+    """
+    speech_tokens = response_text.strip().split('><')
+    if len(speech_tokens) > 1:
+        # Add back the missing '<' and '>' for proper parsing
+        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
+        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
+
+    speech_ids = []
+    for token_str in speech_tokens:
+        match = re.match(r'<\|s_(\d+)\|>', token_str)
+        if match:
+            speech_ids.append(int(match.group(1)))
+    return speech_ids
+
+
+class TritonPythonModel:
+    """Triton Python model for Spark TTS.
+
+    This model orchestrates the end-to-end TTS pipeline by coordinating
+    between audio tokenizer, LLM, and vocoder components.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        self.logger = pb_utils.Logger
+        # Parse model parameters
+        self.model_config = json.loads(args['model_config'])
+        parameters = self.model_config['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential", "equal", or "time_based"
+        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
+
+        # Initialize tokenizer
+        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
+        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
+        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
+        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
+
+        self.device = torch.device("cuda")
+        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
+
+        self.token_frame_rate = 25
+        self.flow_pre_lookahead_len = 3
+        self.token_hop_len = 15
+
+        self.http_client = httpx.AsyncClient()
+        self.api_base = "http://localhost:8000/v1/chat/completions"
+        self.speaker_cache = {}
+
+    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
+        """Converts a tensor or list of speech token IDs to a string representation."""
+        if isinstance(speech_tokens, torch.Tensor):
+            # Ensure tensor is on CPU and flattened
+            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
+
+        speech_id_str = ""
+        for token_id in speech_tokens:
+            # Convert token ID back to the speech number N
+            token_num = token_id - ORIGINAL_VOCAB_SIZE
+            speech_id_str += f"<|s_{token_num}|>"
+        return speech_id_str
+
+    async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
+        """
+        Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
+        """
+        full_text = f"{reference_text}{target_text}"
+        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
+
+        chat = [
+            {"role": "user", "content": full_text},
+            {"role": "assistant", "content": prompt_speech_tokens_str}
+        ]
+
+        payload = {
+            "model": "trt_engines_bfloat16",
+            "messages": chat,
+            "max_tokens": 750,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "stop": ["<|eos1|>", "<|eos|>"],
+            "stream": True,
+        }
+
+        buffer = ""
+        async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    line_data = line[len("data: "):].strip()
+                    if line_data == "[DONE]":
+                        break
+                    try:
+                        json_data = json.loads(line_data)
+                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
+                        if content:
+                            buffer += content
+                            while True:
+                                match = re.search(r"<\|s_(\d+)\|>", buffer)
+                                if not match:
+                                    break
+
+                                token_num = int(match.group(1))
+                                final_id = token_num + ORIGINAL_VOCAB_SIZE
+                                yield final_id
+                                buffer = buffer[match.end():]
+                    except json.JSONDecodeError:
+                        self.logger.log_info(f"Skipping non-JSON line: {line_data}")
+                        continue
+
+        # Process any remaining complete tokens in the buffer after the stream ends
+        while True:
+            match = re.search(r"<\|s_(\d+)\|>", buffer)
+            if not match:
+                break
+            token_num = int(match.group(1))
+            final_id = token_num + ORIGINAL_VOCAB_SIZE
+            yield final_id
+            buffer = buffer[match.end():]
+
+    def forward_audio_tokenizer(self, wav, wav_len):
+        """Forward pass through the audio tokenizer component.
+
+        Args:
+            wav: Input waveform tensor
+            wav_len: Waveform length tensor
+
+        Returns:
+            Prompt speech tokens tensor
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='audio_tokenizer',
+            requested_output_names=['prompt_speech_tokens'],
+            inputs=[wav, wav_len]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
+        prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
+
+        return prompt_speech_tokens
+
+    def forward_speaker_embedding(self, wav):
+        """Forward pass through the speaker embedding component.
+
+        Args:
+            wav: Input waveform tensor
+
+        Returns:
+            Prompt speaker embedding tensor
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='speaker_embedding',
+            requested_output_names=['prompt_spk_embedding'],
+            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
+        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
+
+        return prompt_spk_embedding
+
+    async def forward_token2wav(
+            self,
+            index: int,
+            target_speech_tokens: torch.Tensor,
+            request_id: str,
+            reference_wav: object,
+            reference_wav_len: object,
+            finalize: bool = None) -> torch.Tensor:
+        """Forward pass through the vocoder component.
+
+        Args:
+            index: Index of the request
+            target_speech_tokens: Target speech tokens tensor
+            request_id: Request ID
+            reference_wav: Reference waveform tensor
+            reference_wav_len: Reference waveform length tensor
+            finalize: Whether to finalize the request
+
+        Returns:
+            Generated waveform tensor
+        """
+        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
+        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
+
+        # Create and execute inference request
+        inference_request = pb_utils.InferenceRequest(
+            model_name='token2wav_dit',
+            requested_output_names=[
+                "waveform",
+            ],
+            inputs=inputs_tensor,
+            request_id=request_id,
+            parameters={"priority": index + 1},
+        )
+
+        inference_response = await inference_request.async_exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output waveform
+        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
+        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
+
+        return waveform
+
+    def _extract_speech_feat(self, speech):
+        speech_feat = mel_spectrogram(
+            speech,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000).squeeze(
+            dim=0).transpose(
+            0,
+            1).to(
+                self.device)
+        speech_feat = speech_feat.unsqueeze(dim=0)
+        return speech_feat
+
+    async def _process_request(self, request):
+        request_id = request.request_id()
+
+        reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+        reference_text = reference_text[0][0].decode('utf-8')
+
+        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+        wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+
+        if reference_text not in self.speaker_cache:
+            self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
+        prompt_speech_tokens = self.speaker_cache[reference_text]
+
+        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+        target_text = target_text[0][0].decode('utf-8')
+
+        if self.decoupled:
+            response_sender = request.get_response_sender()
+
+            semantic_token_ids_arr = []
+            token_offset, chunk_index = 0, 0
+            start_time = time.time()
+            this_token_hop_len = self.token_hop_len
+            async for generated_ids in self.forward_llm_async(
+                target_text=target_text,
+                reference_text=reference_text,
+                prompt_speech_tokens=prompt_speech_tokens,
+            ):
+                if not generated_ids:
+                    break
+                semantic_token_ids_arr.append(generated_ids)
+                while True:
+                    pending_num = len(semantic_token_ids_arr) - token_offset
+                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
+                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
+                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                        sub_tts_speech = await self.forward_token2wav(
+                            chunk_index,
+                            this_tts_speech_token, request_id, wav, wav_len, False
+                        )
+                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                        token_offset += this_token_hop_len
+
+                        if self.dynamic_chunk_strategy == "exponential":
+                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "equal":
+                            this_token_hop_len = self.token_hop_len
+                        elif self.dynamic_chunk_strategy == "time_based":
+                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
+                            cost_time = time.time() - start_time
+                            duration = token_offset / self.token_frame_rate
+                            if chunk_index > 0 and cost_time > 0:
+                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
+                                if avg_chunk_processing_time > 0:
+                                    multiples = (duration - cost_time) / avg_chunk_processing_time
+                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
+                                    if multiples > 4:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
+                                    elif multiples > 2:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
+                                    else:
+                                        this_token_hop_len = self.token_hop_len
+                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
+                        chunk_index += 1
+                    else:
+                        break
+
+            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
+            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
+            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+            response_sender.send(inference_response)
+
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+        else:
+            raise NotImplementedError("Offline TTS mode is not supported")
+
+    async def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated audio
+        """
+        tasks = [
+            asyncio.create_task(self._process_request(request))
+            for request in requests
+        ]
+        await asyncio.gather(*tasks)
+        return None
+
+    def finalize(self):
+        self.logger.log_info("Finalizing CosyVoice DIT model")
+        if hasattr(self, "http_client"):
+            asyncio.run(self.http_client.aclose())

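With the constants set in `initialize` (`token_hop_len = 15`, `flow_pre_lookahead_len = 3`, `token_frame_rate = 25`), the default `exponential` strategy emits the first audio chunk as soon as 18 speech tokens are available and then roughly doubles the hop (25, 50, 100, ... tokens), keeping time-to-first-audio low while reducing the number of token2wav calls for long utterances. A standalone sketch of that schedule (bookkeeping only, no model calls):

```python
# Mirrors the exponential chunking above: each chunk covers
# [token_offset, token_offset + hop + lookahead) and the offset advances by hop.
TOKEN_FRAME_RATE = 25       # speech tokens per second
TOKEN_HOP_LEN = 15          # initial hop
FLOW_PRE_LOOKAHEAD_LEN = 3  # extra tokens the flow decoder looks ahead

def chunk_schedule(total_tokens: int):
    token_offset, chunk_index, hop = 0, 0, TOKEN_HOP_LEN
    while total_tokens - token_offset >= hop + FLOW_PRE_LOOKAHEAD_LEN:
        yield chunk_index, token_offset, token_offset + hop + FLOW_PRE_LOOKAHEAD_LEN
        token_offset += hop
        hop = TOKEN_FRAME_RATE * (2 ** chunk_index)
        chunk_index += 1
    yield chunk_index, token_offset, total_tokens  # final chunk, finalize=True

print(list(chunk_schedule(200)))
# [(0, 0, 18), (1, 15, 43), (2, 40, 93), (3, 90, 193), (4, 190, 200)]
```
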
+ 73 - 0
runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt

@@ -0,0 +1,73 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "cosyvoice2_dit"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+model_transaction_policy {
+  decoupled: ${decoupled_mode}
+}
+parameters [
+  {
+   key: "llm_tokenizer_dir",
+   value: {string_value:"${llm_tokenizer_dir}"}
+  },
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+    optional: true
+  },
+  {
+    name: "reference_wav_len"
+    data_type: TYPE_INT32
+    dims: [1]
+    optional: true
+  },
+  {
+    name: "reference_text"
+    data_type: TYPE_STRING
+    dims: [1]
+    optional: true
+  },
+  {
+    name: "target_text"
+    data_type: TYPE_STRING
+    dims: [1]
+  }
+]
+output [
+  {
+    name: "waveform"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: ${bls_instance_num}
+    kind: KIND_CPU
+  }
+]

+ 0 - 1
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -28,7 +28,6 @@ import json
 import os
 import os
 
 
 import logging
 import logging
-from typing import List, Dict
 
 
 import torch
 import torch
 from torch.utils.dlpack import to_dlpack
 from torch.utils.dlpack import to_dlpack

+ 142 - 0
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py

@@ -0,0 +1,142 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import os
+
+import logging
+from typing import List, Dict
+
+import torch
+from torch.utils.dlpack import to_dlpack
+from torch.nn import functional as F
+
+import triton_python_backend_utils as pb_utils
+
+from hyperpyyaml import load_hyperpyyaml
+from cosyvoice.utils.common import fade_in_out
+from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
+from cosyvoice.utils.common import TrtContextWrapper
+from collections import defaultdict
+import numpy as np
+from .token2wav_dit import CosyVoice2_Token2Wav
+import hashlib
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+
+ORIGINAL_VOCAB_SIZE = 151663
+torch.set_num_threads(1)
+
+
+def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
+    """
+    Generates a unique ID for a torch.Tensor.
+    Tensors with the same elements and properties will have the same ID.
+    """
+    # Convert tensor to a byte string
+    tensor_bytes = tensor.numpy().tobytes()
+
+    # Create a SHA-256 hash of the byte string
+    hasher = hashlib.sha256()
+    hasher.update(tensor_bytes)
+
+    return hasher.hexdigest()
+
+
+class TritonPythonModel:
+    """Triton Python model for vocoder.
+
+    This model takes CosyVoice2 speech tokens as input and generates audio waveforms
+    using the DiT flow-matching decoder and HiFT vocoder.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {key: value["string_value"] for key, value in parameters.items()}
+        model_dir = model_params["model_dir"]
+
+        # Initialize device and vocoder
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
+
+        # FIXME: device id settings
+        self.token2wav_model = CosyVoice2_Token2Wav(
+            model_dir, enable_trt=True, streaming=True
+        )
+        logger.info("Token2Wav initialized successfully")
+
+    def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated waveforms
+        """
+        responses = []
+        # Process each request in batch
+        for request in requests:
+            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
+            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
+            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
+            target_speech_tokens = target_speech_tokens.squeeze().tolist()
+
+            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+
+            request_id = request.request_id()
+
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav_len").as_numpy().item()
+
+            wav_array = torch.from_numpy(wav_array)
+            wav = wav_array[:, :wav_len].squeeze(0)
+
+            spk_id = get_spk_id_from_prompt_audio(wav)
+
+            audio_hat = self.token2wav_model.forward_streaming(
+                target_speech_tokens, finalize, request_id=request_id,
+                speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
+            )
+
+            outputs = []
+
+            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
+            outputs.append(wav_tensor)
+            inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
+            responses.append(inference_response)
+
+        return responses

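Because `get_spk_id_from_prompt_audio` hashes the raw bytes of the prompt waveform, byte-identical prompt audio always maps to the same speaker ID, which is what lets `CosyVoice2_Token2Wav` reuse its per-speaker prompt caches across requests. A tiny standalone check of that property:

```python
import hashlib

import torch

def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
    # Same logic as above: SHA-256 over the raw bytes of the waveform.
    return hashlib.sha256(tensor.numpy().tobytes()).hexdigest()

wav_a = torch.zeros(16000)   # 1 s of silence at 16 kHz
wav_b = wav_a.clone()        # same samples, different tensor object
wav_c = torch.randn(16000)   # different samples

assert get_spk_id_from_prompt_audio(wav_a) == get_spk_id_from_prompt_audio(wav_b)
assert get_spk_id_from_prompt_audio(wav_a) != get_spk_id_from_prompt_audio(wav_c)
```
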
+ 510 - 0
runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py

@@ -0,0 +1,510 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Example Usage
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 token2wav_dit.py --enable-trt || exit 1
+"""
+import torch
+# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
+from flashcosyvoice.modules.hifigan import HiFTGenerator
+from flashcosyvoice.utils.audio import mel_spectrogram
+import torchaudio.compliance.kaldi as kaldi
+import onnxruntime
+import s3tokenizer
+from torch.utils.data import DataLoader
+from datasets import load_dataset
+import torchaudio
+import os
+import logging
+import argparse
+import queue
+import time
+import numpy as np
+from hyperpyyaml import load_hyperpyyaml
+
+
+def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
+    """perform fade_in_out in tensor style
+    """
+    mel_overlap_len = int(window.shape[0] / 2)
+    fade_in_mel = fade_in_mel.clone()
+    fade_in_mel[..., :mel_overlap_len] = \
+        fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
+    return fade_in_mel
+
+
+def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
+    import tensorrt as trt
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
+    if dtype == torch.float16:
+        config.set_flag(trt.BuilderFlag.FP16)
+
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(trt_kwargs['input_names'])):
+        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
+    if dtype == torch.float16:
+        tensor_dtype = trt.DataType.HALF
+    elif dtype == torch.bfloat16:
+        tensor_dtype = trt.DataType.BF16
+    elif dtype == torch.float32:
+        tensor_dtype = trt.DataType.FLOAT
+    else:
+        raise ValueError('invalid dtype {}'.format(dtype))
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
+    logging.info("Succesfully convert onnx to trt...")
+
+
+class TrtContextWrapper:
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        self.trt_engine = trt_engine
+        self.device = device
+        for _ in range(trt_concurrent):
+            trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
+            assert trt_context is not None, 'failed to create trt context, possibly not enough CUDA memory; try reducing trt_concurrent from {}'.format(trt_concurrent)
+            self.trt_context_pool.put([trt_context, trt_stream])
+        assert self.trt_context_pool.empty() is False, 'no available estimator context'
+
+    def acquire_estimator(self):
+        return self.trt_context_pool.get(), self.trt_engine
+
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])
+
+
+class CosyVoice2_Token2Wav(torch.nn.Module):
+    def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
+        super().__init__()
+        self.device_id = device_id
+        self.device = f"cuda:{device_id}"
+        with open(f"{model_dir}/flow.yaml", "r") as f:
+            configs = load_hyperpyyaml(f)
+            self.flow = configs['flow']
+
+        self.dtype = dtype
+        self.flow.to(self.dtype)
+
+        self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
+        self.flow.to(self.device).eval()
+
+        self.hift = HiFTGenerator()
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.spk_model = onnxruntime.InferenceSession(
+            f"{model_dir}/campplus.onnx", sess_options=option,
+            providers=["CPUExecutionProvider"])
+        self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
+
+        gpu = "l20"
+        if enable_trt:
+            if streaming:
+                self.load_trt(
+                    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
+                    f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
+                    1,
+                    self.dtype, streaming
+                )
+            else:
+                self.load_trt(
+                    f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
+                    f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                    1,
+                    self.dtype
+                )
+            self.load_spk_trt(
+                f'{model_dir}/campplus.{gpu}.fp32.trt',
+                f'{model_dir}/campplus.onnx',
+                1,
+                False
+            )
+
+        self.streaming_flow_cache = {}
+        self.speaker_cache = {}
+
+        self.mel_cache_len = 8  # hard-coded, 160ms
+        self.source_cache_len = int(self.mel_cache_len * 480)   # 50hz mel -> 24kHz wave
+        self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
+
+        # hifigan cache for streaming tts
+        self.hift_cache_dict = {}
+
+    def forward_spk_embedding(self, spk_feat):
+        if isinstance(self.spk_model, onnxruntime.InferenceSession):
+            return self.spk_model.run(
+                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+            )[0].flatten().tolist()
+        else:
+            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
+            # NOTE need to synchronize when switching stream
+            with torch.cuda.device(self.device_id):
+                torch.cuda.current_stream().synchronize()
+                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
+                batch_size = spk_feat.size(0)
+
+                with stream:
+                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
+                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
+
+                    data_ptrs = [spk_feat.contiguous().data_ptr(),
+                                 output_tensor.contiguous().data_ptr()]
+                    for i, j in enumerate(data_ptrs):
+                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
+                    # run trt engine
+                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                    torch.cuda.current_stream().synchronize()
+                self.spk_model.release_estimator(spk_model, stream)
+
+            return output_tensor.cpu().numpy().flatten().tolist()
+
+    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
+        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
+            trt_kwargs = self.get_spk_trt_kwargs()
+            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float32)
+        import tensorrt as trt
+        with open(spk_model, 'rb') as f:
+            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
+        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_spk_trt_kwargs(self):
+        min_shape = [(1, 4, 80)]
+        opt_shape = [(1, 500, 80)]
+        max_shape = [(1, 3000, 80)]
+        input_names = ["input"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
+            opt_batch_size = 2
+            max_batch_size = 16
+            if streaming:
+                opt_batch_size, max_batch_size = 1, 1  # only support batch size 1 for streaming tts
+            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
+            convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
+        if streaming:
+            min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
+            opt_shape = [
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500),
+                (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2),
+                (16, opt_batch_size * 2, 8, 100, 128)
+            ]
+            max_shape = [
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000),
+                (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2),
+                (16, max_batch_size * 2, 8, 1000, 128)
+            ]
+            input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
+        else:
+            min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
+            opt_shape = [
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500),
+                (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)
+            ]
+            max_shape = [
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000),
+                (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)
+            ]
+            input_names = ["x", "mask", "mu", "cond", "t", "spks"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
+        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
+            prompt_speech_mels_list.append(log_mel)
+        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
+        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
+            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
+        )
+        for i in range(len(prompt_speech_tokens)):
+            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
+            prompt_speech_tokens_list.append(speech_tokens_i)
+        return prompt_speech_tokens_list
+
+    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
+        spk_emb_for_flow = []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
+            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
+            spk_emb = self.forward_spk_embedding(spk_feat)
+
+            spk_emb_for_flow.append(spk_emb)
+        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
+        if self.dtype != torch.float32:
+            spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
+        return spk_emb_for_flow
+
+    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
+        prompt_mels_for_flow = []
+        prompt_mels_lens_for_flow = []
+        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
+            assert len(audio.shape) == 1
+            audio = audio.unsqueeze(0)
+            if sample_rate != 24000:
+                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
+            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
+            mel_len = mel.shape[0]
+            prompt_mels_for_flow.append(mel)
+            prompt_mels_lens_for_flow.append(mel_len)
+        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
+            prompt_mels_for_flow, batch_first=True, padding_value=0
+        )  # [B, T', num_mels=80]
+        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
+        return prompt_mels_for_flow, prompt_mels_lens_for_flow
+
+    def forward_flow(self, prompt_speech_tokens_list: list[list[int]],
+                     generated_speech_tokens_list: list[list[int]],
+                     prompt_mels_for_flow: torch.Tensor,
+                     prompt_mels_lens_for_flow: torch.Tensor,
+                     spk_emb_for_flow: torch.Tensor):
+        batch_size = prompt_mels_for_flow.shape[0]
+        flow_inputs = []
+        flow_inputs_lens = []
+        for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
+            flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
+            flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
+
+        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
+        flow_inputs_lens = torch.tensor(flow_inputs_lens)
+
+        with torch.amp.autocast("cuda", dtype=torch.float16):
+            generated_mels, generated_mels_lens = self.flow.inference(
+                flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
+                prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
+            )
+
+        return generated_mels, generated_mels_lens
+
+    def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
+        batch_size = generated_mels.shape[0]
+        generated_wavs = []
+        for i in range(batch_size):
+            mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
+            wav, _ = self.hift(speech_feat=mel)
+            generated_wavs.append(wav)
+        return generated_wavs
+
+    @torch.inference_mode()
+    def forward(
+        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
+    ):
+        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
+
+        prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
+
+        generated_mels, generated_mels_lens = self.forward_flow(
+            prompt_speech_tokens_list, generated_speech_tokens_list,
+            prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+        )
+
+        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
+        return generated_wavs
+
+    def prepare_prompt_audio(
+        self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
+    ):
+        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
+
+        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
+
+        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
+
+        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
+        return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+
+    def get_prompt_audio_cache_for_streaming_tts(
+        self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+    ):
+        assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
+        for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
+            prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
+        prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
+
+        cache = self.flow.setup_cache(
+            prompt_speech_tokens_tensor.to(self.device),
+            prompt_mels_for_flow.to(self.device),
+            spk_emb_for_flow.to(self.device),
+            n_timesteps=10
+        )
+        new_cache = {k: v.clone() for k, v in cache.items()}
+        # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
+        return new_cache
+
+    @torch.inference_mode()
+    def forward_streaming(
+        self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
+    ):
+        if speaker_id not in self.speaker_cache:
+            assert prompt_audio is not None, "prompt_audio is required for new speaker"
+            assert prompt_audio_sample_rate == 16000
+
+            prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])
+
+            token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
+            prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
+            prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]
+
+            prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
+
+            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+            self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
+
+        if request_id not in self.streaming_flow_cache:
+            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
+            self.hift_cache_dict[request_id] = dict(
+                mel=torch.zeros(1, 80, 0, device='cuda'),
+                source=torch.zeros(1, 1, 0, device='cuda'),
+                speech=torch.zeros(1, 0, device='cuda'),
+            )
+
+        current_request_cache = self.streaming_flow_cache[request_id]
+
+        current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
+        generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
+
+        chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
+            token=generated_speech_tokens,
+            spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
+            cache=current_request_cache,
+            last_chunk=last_chunk,
+            n_timesteps=10,
+        )
+
+        self.streaming_flow_cache[request_id] = new_streaming_flow_cache
+
+        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
+            self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
+            ], dim=4)
+
+        hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
+        hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
+        hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
+        mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()
+
+        speech, source = self.hift(mel, hift_cache_source)
+
+        # overlap speech smooth
+        if hift_cache_speech.shape[-1] > 0:
+            speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
+
+        # update vocoder cache
+        self.hift_cache_dict[request_id] = dict(
+            mel=mel[..., -self.mel_cache_len:].clone().detach(),
+            source=source[:, :, -self.source_cache_len:].clone().detach(),
+            speech=speech[:, -self.source_cache_len:].clone().detach(),
+        )
+        if not last_chunk:
+            speech = speech[:, :-self.source_cache_len]
+
+        if last_chunk:
+            assert request_id in self.streaming_flow_cache
+            self.streaming_flow_cache.pop(request_id)
+            self.hift_cache_dict.pop(request_id)
+
+        return speech
+
+
+def collate_fn(batch):
+    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
+    for item in batch:
+        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
+        audio = torch.from_numpy(item['prompt_audio']['array']).float()
+        prompt_audios_list.append(audio)
+        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
+        ids.append(item['id'])
+
+    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--enable-trt", action="store_true")
+    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--output-dir", type=str, default="generated_wavs")
+    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
+    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    dataset_name = "yuekai/seed_tts_cosy2"
+
+    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
+
+    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
+
+    for _ in range(args.warmup):
+        start_time = time.time()
+        for batch in data_loader:
+            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
+
+            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
+
+            for id, wav in zip(ids, generated_wavs):
+                torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
+        end_time = time.time()
+        epoch_time = end_time - start_time
+        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")

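In `forward_streaming`, each request keeps the last `mel_cache_len = 8` mel frames and `source_cache_len = 8 * 480 = 3840` waveform samples, and every new chunk is cross-faded into that cached tail with a Hamming window before the overlap is trimmed. A toy-sized sketch of the cross-fade (the `fade_in_out` body is copied from this file; sizes shrunk for readability):

```python
import numpy as np
import torch

def fade_in_out(fade_in_mel, fade_out_mel, window):
    # Copied from token2wav_dit.py: blend the head of the new chunk with the
    # tail of the previous chunk over half the window length.
    overlap = int(window.shape[0] / 2)
    fade_in_mel = fade_in_mel.clone()
    fade_in_mel[..., :overlap] = (fade_in_mel[..., :overlap] * window[:overlap]
                                  + fade_out_mel[..., -overlap:] * window[overlap:])
    return fade_in_mel

overlap = 4                                    # toy size; the real overlap is 3840 samples
window = torch.from_numpy(np.hamming(2 * overlap))
prev_tail = torch.ones(1, overlap)             # cached tail of the previous chunk
new_chunk = torch.zeros(1, 8)                  # freshly vocoded chunk

smoothed = fade_in_out(new_chunk, prev_tail, window)
print(smoothed[0, :overlap])  # ramps from the previous chunk's level toward the new chunk's
```
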
+ 69 - 0
runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt

@@ -0,0 +1,69 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "token2wav_dit"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+    priority_levels: 10
+    default_priority_level: 10
+}
+
+parameters [
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "target_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+  },
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+  },
+  {
+    name: "reference_wav_len"
+    data_type: TYPE_INT32
+    dims: [1]
+  },
+  {
+    name: "finalize"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  }
+]
+output [
+  {
+    name: "waveform"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]

+ 99 - 10
runtime/triton_trtllm/offline_inference.py

@@ -28,7 +28,6 @@ import argparse
 import json
 import json
 import os
 import os
 import sys
 import sys
-from pathlib import Path
 
 
 import torch
 import torch
 import torch.distributed as dist
 import torch.distributed as dist
@@ -43,8 +42,9 @@ import soundfile as sf
 import s3tokenizer
 import s3tokenizer
 from functools import partial
 from functools import partial
 import time
 import time
-
-from token2wav import CosyVoice2_Token2Wav
+import requests
+import asyncio
+import httpx
 
 
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 try:
 try:
@@ -53,6 +53,32 @@ except RuntimeError:
     pass
     pass
 
 
 
 
+async def send_request_async(client, url, payload):
+    response = await client.post(url, json=payload, timeout=None)
+    response.raise_for_status()
+    response_json = response.json()
+    return response_json['choices'][0]['message']['content']
+
+
+async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
+    async with httpx.AsyncClient() as client:
+        tasks = []
+        for chat in chats:
+            payload = {
+                "model": model_name,
+                "messages": chat,
+                "max_tokens": 2048,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": 1.1,
+                "stop": ["<|eos1|>", "<|eos|>"],
+                "stream": False,
+            }
+            tasks.append(send_request_async(client, api_base, payload))
+        return await asyncio.gather(*tasks)
+
+
 def extract_speech_ids(speech_tokens_str):
 def extract_speech_ids(speech_tokens_str):
     """Extract speech IDs from token strings like <|s_23456|>"""
     """Extract speech IDs from token strings like <|s_23456|>"""
     speech_ids = []
     speech_ids = []
@@ -149,7 +175,7 @@ def get_args():
         "--backend",
         "--backend",
         type=str,
         type=str,
         default="hf",
         default="hf",
-        choices=["hf", "trtllm", "vllm"],
+        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
         help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
         help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
     )
     )
     parser.add_argument(
     parser.add_argument(
@@ -164,6 +190,18 @@ def get_args():
         default=0.6,
         default=0.6,
         help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
         help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
     )
     )
+    parser.add_argument(
+        "--openai-api-base",
+        type=str,
+        default="http://localhost:8000/v1/chat/completions",
+        help="OpenAI API base URL (for trtllm-serve backend)",
+    )
+    parser.add_argument(
+        "--openai-model-name",
+        type=str,
+        default="trt_engines_bfloat16",
+        help="Model name to use with OpenAI API (for trtllm-serve backend)",
+    )
     args = parser.parse_args()
     args = parser.parse_args()
     return args
     return args
 
 
@@ -180,6 +218,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     prompt_text_after_apply_template_list = []
     prompt_text_after_apply_template_list = []
     mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
     mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
+    chat_list = []
     for _, item in enumerate(batch):
     for _, item in enumerate(batch):
         audio_processing_start_time = time.time()
         audio_processing_start_time = time.time()
         prompt_text, target_text = (
         prompt_text, target_text = (
@@ -237,6 +276,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             {"role": "user", "content": full_text_list[i]},
             {"role": "user", "content": full_text_list[i]},
             {"role": "assistant", "content": prompt_audio_cosy2_id_str}
             {"role": "assistant", "content": prompt_audio_cosy2_id_str}
         ]
         ]
+        chat_list.append(chat)
 
 
         assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
         assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
 
 
@@ -265,6 +305,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         "audio_processing_time": total_audio_processing_time,
         "audio_processing_time": total_audio_processing_time,
         "speech_tokenization_time": total_speech_tokenization_time,
         "speech_tokenization_time": total_speech_tokenization_time,
         "text_tokenization_time": total_text_tokenization_time,
         "text_tokenization_time": total_text_tokenization_time,
+        "chat_list": chat_list
     }
     }
 
 
 
 
@@ -318,9 +359,16 @@ def main(args):
     elif args.backend == "vllm":
     elif args.backend == "vllm":
         model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
         model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
         runner = None
         runner = None
+    elif args.backend == "trtllm-serve":
+        model = None
+        runner = None
     else:
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
         raise ValueError(f"Unsupported backend: {args.backend}")
-
+    if 'Step-Audio-2-mini' in args.token2wav_path:
+        from token2wav_dit import CosyVoice2_Token2Wav
+    else:
+        assert 'CosyVoice2-0.5B' in args.token2wav_path
+        from token2wav import CosyVoice2_Token2Wav
     token2wav_model = CosyVoice2_Token2Wav(
     token2wav_model = CosyVoice2_Token2Wav(
         model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
         model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
     )
     )
@@ -452,6 +500,35 @@ def main(args):
                     print(outputs)
                     print(outputs)
                     for j, output in enumerate(outputs):
                     for j, output in enumerate(outputs):
                         outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
                         outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
+                elif args.backend == "trtllm-serve":
+                    if args.batch_size > 1:
+                        outputs = asyncio.run(send_batch_requests_async(
+                            args.openai_api_base,
+                            args.openai_model_name,
+                            batch["chat_list"],
+                            args.temperature,
+                            args.top_p,
+                            args.top_k,
+                        ))
+                    else:
+                        outputs = []
+                        for chat in batch["chat_list"]:
+                            payload = {
+                                "model": args.openai_model_name,
+                                "messages": chat,
+                                "max_tokens": 2048,
+                                "temperature": args.temperature,
+                                "top_p": args.top_p,
+                                "top_k": args.top_k,
+                                "repetition_penalty": 1.1,
+                                "stop": ["<|eos1|>", "<|eos|>"],
+                                "stream": False,
+                            }
+                            response = requests.post(args.openai_api_base, json=payload)
+                            response.raise_for_status()
+                            response_json = response.json()
+                            generated_content = response_json['choices'][0]['message']['content']
+                            outputs.append(generated_content)
 
 
                 llm_end_time = time.time()
                 llm_end_time = time.time()
                 total_llm_time += (llm_end_time - llm_start_time)
                 total_llm_time += (llm_end_time - llm_start_time)
@@ -459,10 +536,21 @@ def main(args):
                 items_for_token_2wav = []
                 items_for_token_2wav = []
                 for i in range(len(batch["ids"])):
                 for i in range(len(batch["ids"])):
                     llm_post_processing_start_time = time.time()
                     llm_post_processing_start_time = time.time()
-                    input_length = len(batch["input_ids"][i])
-                    generated_ids = outputs[i][input_length:]
-                    speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-                    speech_ids = extract_speech_ids(speech_tokens_str)
+                    if args.backend == "trtllm-serve":
+                        speech_tokens_str = outputs[i].strip().split('><')
+                        if len(speech_tokens_str) > 1:
+                            speech_tokens_str = [
+                                t if t.startswith('<') else '<' + t for t in speech_tokens_str
+                            ]
+                            speech_tokens_str = [
+                                t if t.endswith('>') else t + '>' for t in speech_tokens_str
+                            ]
+                        speech_ids = extract_speech_ids(speech_tokens_str)
+                    else:
+                        input_length = len(batch["input_ids"][i])
+                        generated_ids = outputs[i][input_length:]
+                        speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+                        speech_ids = extract_speech_ids(speech_tokens_str)
                     print(i, speech_ids)
                     print(i, speech_ids)
                     if len(speech_ids) == 0:
                     if len(speech_ids) == 0:
                         print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                         print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
@@ -502,7 +590,6 @@ def main(args):
                         t2w_prompt_audios_list,
                         t2w_prompt_audios_list,
                         t2w_prompt_audios_sample_rate,
                         t2w_prompt_audios_sample_rate,
                     )
                     )
-                    torch.cuda.synchronize()
                     token2wav_end_time = time.time()
                     token2wav_end_time = time.time()
                     total_token2wav_time += (token2wav_end_time - token2wav_start_time)
                     total_token2wav_time += (token2wav_end_time - token2wav_start_time)
 
 
@@ -558,6 +645,8 @@ if __name__ == "__main__":
         from tensorrt_llm.runtime import ModelRunnerCpp
         from tensorrt_llm.runtime import ModelRunnerCpp
     elif args.backend == "hf":
     elif args.backend == "hf":
         from transformers import AutoModelForCausalLM
         from transformers import AutoModelForCausalLM
+    elif args.backend == "trtllm-serve":
+        pass
     else:
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
         raise ValueError(f"Unsupported backend: {args.backend}")
     main(args)
     main(args)

+ 225 - 0
runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh

@@ -0,0 +1,225 @@
+#!/bin/bash
+# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
+export CUDA_VISIBLE_DEVICES=0
+cosyvoice_path=/workspace/CosyVoice
+stepaudio2_path=/workspace/Step-Audio2
+
+export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
+export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
+export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
+
+stage=$1
+stop_stage=$2
+
+huggingface_model_local_dir=./cosyvoice2_llm
+model_scope_model_local_dir=./CosyVoice2-0.5B
+step_audio_model_dir=./Step-Audio-2-mini
+
+trt_dtype=bfloat16
+trt_weights_dir=./trt_weights_${trt_dtype}
+trt_engines_dir=./trt_engines_${trt_dtype}
+
+model_repo=./model_repo_cosyvoice2_dit
+bls_instance_num=10
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+
+    echo "Cloning Step-Audio2-mini"
+    git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path
+
+    echo "Cloning CosyVoice"
+    git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
+    cd $cosyvoice_path
+    git submodule update --init --recursive
+    cd runtime/triton_trtllm
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+    echo "Downloading CosyVoice2-0.5B"
+    # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
+    huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
+    modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
+
+    echo "Step-Audio2-mini"
+    huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
+    cd $step_audio_model_dir/token2wav
+    wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
+    wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
+    cd -
+fi
+
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+    echo "Converting checkpoint to TensorRT weights"
+    python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
+                                --output_dir $trt_weights_dir \
+                                --dtype $trt_dtype || exit 1
+
+    echo "Building TensorRT engines"
+    trtllm-build --checkpoint_dir $trt_weights_dir \
+                --output_dir $trt_engines_dir \
+                --max_batch_size 64 \
+                --max_num_tokens 32768 \
+                --gemm_plugin $trt_dtype || exit 1
+
+    echo "Testing TensorRT engines"
+    python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
+                    --tokenizer_dir $huggingface_model_local_dir \
+                    --top_k 50 --top_p 0.95 --temperature 0.8 \
+                    --engine_dir=$trt_engines_dir  || exit 1
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+    echo "Creating model repository async mode"
+    rm -rf $model_repo
+    mkdir -p $model_repo
+    cosyvoice2_dir="cosyvoice2_dit"
+    token2wav_dir="token2wav_dit"
+
+    cp -r ./model_repo/${cosyvoice2_dir} $model_repo
+    cp -r ./model_repo/${token2wav_dir} $model_repo
+    cp -r ./model_repo/audio_tokenizer $model_repo
+    cp -r ./model_repo/speaker_embedding $model_repo
+
+
+    ENGINE_PATH=$trt_engines_dir
+    MAX_QUEUE_DELAY_MICROSECONDS=0
+    MODEL_DIR=$model_scope_model_local_dir
+    LLM_TOKENIZER_DIR=$huggingface_model_local_dir
+    BLS_INSTANCE_NUM=$bls_instance_num
+    TRITON_MAX_BATCH_SIZE=1
+    DECOUPLED_MODE=True # Only streaming TTS mode is supported with NVIDIA Triton for now
+    STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav
+
+    python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+
+fi
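
For orientation, a minimal sketch of the ${placeholder} substitution that the fill_template.py calls above rely on (illustrative only; the real scripts/fill_template.py may differ in parsing details):

```python
# Illustrative sketch of ${placeholder} substitution in a config.pbtxt template.
from string import Template

def fill(pbtxt_path: str, assignments: str) -> None:
    # assignments look like "model_dir:./Step-Audio-2-mini/token2wav,triton_max_batch_size:1,..."
    mapping = dict(kv.split(":", 1) for kv in assignments.split(","))
    filled = Template(open(pbtxt_path).read()).safe_substitute(mapping)
    open(pbtxt_path, "w").write(filled)

# fill("model_repo_cosyvoice2_dit/token2wav_dit/config.pbtxt",
#      "model_dir:./Step-Audio-2-mini/token2wav,triton_max_batch_size:1,max_queue_delay_microseconds:0")
```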
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+   echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
+   mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64  --kv_cache_free_gpu_memory_fraction 0.4 &
+   tritonserver --model-repository $model_repo --http-port 18000 &
+   wait
+    # Test using curl
+    # curl http://localhost:8000/v1/chat/completions \
+    #     -H "Content-Type: application/json" \
+    #     -d '{
+    #         "model": "",
+    #         "messages":[{"role": "user", "content": "Where is New York?"},
+    #                     {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
+    #         "max_tokens": 512,
+    #         "temperature": 0.8,
+    #         "top_p": 0.95,
+    #         "top_k": 50,
+    #         "stop": ["<|eos1|>"],
+    #         "repetition_penalty": 1.2,
+    #         "stream": false
+    #     }'
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+    echo "Running benchmark client"
+    num_task=4
+    mode=streaming
+    BLS_INSTANCE_NUM=$bls_instance_num
+
+    python3 client_grpc.py \
+        --server-addr localhost \
+        --server-port 8001 \
+        --model-name cosyvoice2_dit \
+        --num-tasks $num_task \
+        --mode $mode \
+        --huggingface-dataset yuekai/seed_tts_cosy2 \
+        --log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
+
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
+
+  datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
+  backend=trtllm # hf, trtllm, vllm, trtllm-serve
+
+  batch_sizes=(16)
+  token2wav_batch_size=1
+
+  for batch_size in ${batch_sizes[@]}; do
+    for dataset in ${datasets[@]}; do
+    output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
+    CUDA_VISIBLE_DEVICES=1 \
+        python3 offline_inference.py \
+            --output-dir $output_dir \
+            --llm-model-name-or-path $huggingface_model_local_dir \
+            --token2wav-path $step_audio_model_dir/token2wav \
+            --backend $backend \
+            --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+            --engine-dir $trt_engines_dir \
+            --split-name ${dataset} || exit 1
+    done
+  done
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+   echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script"
+   export CUDA_VISIBLE_DEVICES=1
+   # Note: uses pre-computed CosyVoice2 speech tokens from the dataset
+   python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
+   # Offline Token2wav inference
+   python3 token2wav_dit.py --enable-trt
+fi
+
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+   echo "Disaggregated Server: LLM and Token2wav on different GPUs"
+   echo "Starting LLM server on GPU 0"
+   export CUDA_VISIBLE_DEVICES=0
+   mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 64  --kv_cache_free_gpu_memory_fraction 0.4 &
+   echo "Starting Token2wav server on GPUs 1-3"
+   Token2wav_num_gpus=3
+   http_port=17000
+   grpc_port=18000
+   metrics_port=16000
+   for i in $(seq 0 $(($Token2wav_num_gpus - 1))); do
+       echo "Starting server on GPU $i"
+       http_port=$((http_port + 1))
+       grpc_port=$((grpc_port + 1))
+       metrics_port=$((metrics_port + 1))
+       # Launch two Token2wav server instances on the same GPU, bumping the ports between launches
+       CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
+       http_port=$((http_port + 1))
+       grpc_port=$((grpc_port + 1))
+       metrics_port=$((metrics_port + 1))
+       CUDA_VISIBLE_DEVICES=$(($i + 1)) tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
+   done
+   wait
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+    echo "Running benchmark client for Disaggregated Server"
+    per_gpu_instances=2
+    mode=streaming
+    BLS_INSTANCE_NUM=$bls_instance_num
+    Token2wav_num_gpus=(1 2 3)
+    concurrent_tasks=(1 2 3 4 5 6)
+    for n_gpu in ${Token2wav_num_gpus[@]}; do
+        echo "Test 1 GPU for LLM server and $n_gpu GPUs for Token2wav servers"
+        for concurrent_task in ${concurrent_tasks[@]}; do
+            num_instances=$((per_gpu_instances * n_gpu))
+            for i in $(seq 1 $num_instances); do
+                port=$(($i + 18000))
+                python3 client_grpc.py \
+                    --server-addr localhost \
+                    --server-port $port \
+                    --model-name cosyvoice2_dit \
+                    --num-tasks $concurrent_task \
+                    --mode $mode \
+                    --huggingface-dataset yuekai/seed_tts_cosy2 \
+                    --log-dir ./log_disagg_concurrent_tasks_${concurrent_task}_per_instance_total_token2wav_instances_${num_instances}_port_${port} &
+            done
+            wait
+        done
+    done
+fi
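
For orientation, the GPU/port layout that stages 7 and 8 assume can be reproduced with the small sketch below (illustrative; it simply mirrors the arithmetic in the loops above):

```python
# Illustrative reconstruction of the token2wav gRPC port layout from stage 7.
def token2wav_grpc_ports(num_gpus, per_gpu_instances=2, grpc_base=18000):
    layout, port = [], grpc_base
    for gpu in range(1, num_gpus + 1):        # token2wav servers occupy GPUs 1..num_gpus
        for _ in range(per_gpu_instances):    # two Triton instances per GPU
            port += 1
            layout.append((gpu, port))
    return layout

print(token2wav_grpc_ports(3))
# [(1, 18001), (1, 18002), (2, 18003), (2, 18004), (3, 18005), (3, 18006)]
# Stage 8 then points one client_grpc.py process at each of these ports in parallel.
```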

+ 0 - 5
runtime/triton_trtllm/scripts/test_llm.py

@@ -15,11 +15,6 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import argparse
 import argparse
-import ast
-import csv
-import os
-from pathlib import Path
-from typing import List, Optional
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch

+ 122 - 0
runtime/triton_trtllm/streaming_inference.py

@@ -0,0 +1,122 @@
+import torch
+import os
+import argparse
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+import numpy as np
+import torchaudio
+import time
+from token2wav_dit import CosyVoice2_Token2Wav
+import soundfile as sf
+
+
+def collate_fn(batch):
+    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
+    prompt_speech_tokens_list, prompt_text_list = [], []
+    for item in batch:
+        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
+        audio = torch.from_numpy(item['prompt_audio']['array']).float()
+        prompt_audios_list.append(audio)
+        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
+        ids.append(item['id'])
+        prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
+        prompt_text_list.append(item['prompt_text'])
+
+    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--enable-trt", action="store_true")
+    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--output-dir", type=str, default="generated_wavs")
+    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
+    parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
+    parser.add_argument("--strategy", type=str, default="equal", choices=["equal", "exponential"])
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    dataset_name = args.dataset_name
+    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
+    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
+
+    token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
+
+    CHUNK_SIZE = 25
+    token_frame_rate = 25
+    OVERLAP_SIZE = 0
+
+    warmup_times = 3
+    for _ in range(warmup_times):
+        start_time = time.time()
+        total_forward_count = 0
+        for batch in data_loader:
+            tts_speech_list = []
+            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
+
+            id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
+
+            assert prompt_audio_sample_rate == 16000
+
+            prompt_text = prompt_text_list[0]
+            prompt_speech_tokens = prompt_speech_tokens_list[0]
+
+            semantic_token_ids_arr, token_offset = [], 0
+            flow_prompt_speech_token_len = len(prompt_speech_tokens)
+
+            buffer = generated_speech_tokens
+            output_wavs = []
+            chunk_index = 0
+            while True:
+                if args.strategy == "equal":
+                    this_chunk_size = CHUNK_SIZE
+                elif args.strategy == "exponential":
+                    this_chunk_size = token_frame_rate * (2 ** chunk_index)
+
+                if len(buffer) >= this_chunk_size + token2wav_model.flow.pre_lookahead_len:
+                    wavs = token2wav_model.forward_streaming(
+                        buffer[:this_chunk_size + token2wav_model.flow.pre_lookahead_len],
+                        False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio,
+                        prompt_audio_sample_rate=prompt_audio_sample_rate
+                    )
+                    buffer = buffer[this_chunk_size - OVERLAP_SIZE:]
+
+                    output_wavs.append(wavs)
+                    total_forward_count += 1
+                    chunk_index += 1
+
+                else:
+                    wavs = token2wav_model.forward_streaming(
+                        buffer, True, request_id=id, speaker_id=f"{id}",
+                        prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate
+                    )
+                    output_wavs.append(wavs)
+                    total_forward_count += 1
+                    # chunk_index += 1
+                    break
+
+            for i, wav in enumerate(output_wavs):
+                output_wavs[i] = wav.cpu().numpy().squeeze()
+
+            audios = output_wavs
+            reconstructed_audio = np.concatenate(audios)
+            sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
+
+        end_time = time.time()
+
+        if _ == 0:
+            token2wav_model.speaker_cache = {}
+            print(f"Warmup time: {end_time - start_time} seconds")
+            print("clear speaker cache")
+        elif _ == 1:
+            print(f"Cost time without speaker cache: {end_time - start_time} seconds")
+        else:
+            print(f"Cost time with speaker cache: {end_time - start_time} seconds")
+            print(f"Total flow matching forward calls: {total_forward_count}")

+ 1 - 0
runtime/triton_trtllm/token2wav_dit.py

@@ -0,0 +1 @@
+model_repo/token2wav_dit/1/token2wav_dit.py