
Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main

lyuxiang.lx, 2 months ago
commit 4d60ff6abc
28 changed files with 3991 additions and 178 deletions
  1. README.md (+12 -1)
  2. docker/Dockerfile (+1 -1)
  3. examples/grpo/cosyvoice2/Dockerfile (+6 -0)
  4. examples/grpo/cosyvoice2/README.md (+125 -0)
  5. examples/grpo/cosyvoice2/huggingface_to_pretrained.py (+71 -0)
  6. examples/grpo/cosyvoice2/infer_dataset.py (+397 -0)
  7. examples/grpo/cosyvoice2/prepare_data.py (+86 -0)
  8. examples/grpo/cosyvoice2/pretrained_to_huggingface.py (+135 -0)
  9. examples/grpo/cosyvoice2/requirements.txt (+31 -0)
  10. examples/grpo/cosyvoice2/reward_tts.py (+233 -0)
  11. examples/grpo/cosyvoice2/run.sh (+159 -0)
  12. examples/grpo/cosyvoice2/scripts/compute_wer.sh (+33 -0)
  13. examples/grpo/cosyvoice2/scripts/offline-decode-files.py (+756 -0)
  14. examples/grpo/cosyvoice2/token2wav_asr_server.py (+346 -0)
  15. runtime/python/Dockerfile (+1 -1)
  16. runtime/triton_trtllm/README.md (+92 -37)
  17. runtime/triton_trtllm/client_grpc.py (+55 -33)
  18. runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py (+1 -1)
  19. runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt (+1 -1)
  20. runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py (+166 -57)
  21. runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt (+5 -2)
  22. runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py (+153 -0)
  23. runtime/triton_trtllm/model_repo/speaker_embedding/config.pbtxt (+48 -0)
  24. runtime/triton_trtllm/model_repo/token2wav/1/model.py (+116 -33)
  25. runtime/triton_trtllm/model_repo/token2wav/config.pbtxt (+18 -1)
  26. runtime/triton_trtllm/offline_inference.py (+563 -0)
  27. runtime/triton_trtllm/run.sh (+46 -10)
  28. runtime/triton_trtllm/token2wav.py (+335 -0)

README.md (+12 -1)

@@ -31,7 +31,7 @@
 
 - [x] 2025/08
 
-    - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support
+    - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support
 
 - [x] 2025/07
 
@@ -246,6 +246,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o
 cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 ```
 
+#### Using NVIDIA TensorRT-LLM for deployment
+
+Using TensorRT-LLM to accelerate the CosyVoice2 LLM can deliver roughly a 4x speedup compared with the Hugging Face Transformers implementation.
+To get started quickly:
+
+``` sh
+cd runtime/triton_trtllm
+docker compose up -d
+```
+For more details, see [runtime/triton_trtllm](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm).
+
 ## Discussion & Communication
 
 You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).

docker/Dockerfile (+1 -1)

@@ -46,6 +46,6 @@ RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
 
 RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
 RUN conda activate ${VENV} && cd CosyVoice && \
-    pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+    pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
 
 WORKDIR /workspace/CosyVoice

examples/grpo/cosyvoice2/Dockerfile (+6 -0)

@@ -0,0 +1,6 @@
+FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
+COPY requirements.txt /myworkspace/requirements.txt
+RUN pip install -r /myworkspace/requirements.txt
+RUN pip install -U nvidia-pytriton
+RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e .
+RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e .

examples/grpo/cosyvoice2/README.md (+125 -0)

@@ -0,0 +1,125 @@
+# CosyVoice2 LLM Reinforcement Learning Recipe
+
+This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
+
+## Table of Contents
+
+- [Environment Setup](#environment-setup)
+- [Data Preparation](#data-preparation)
+- [Reward Function & ASR Server](#reward-function--asr-server)
+- [Training](#training)
+- [Evaluation](#evaluation)
+- [Export Model](#export-model)
+- [Results](#results)
+- [Acknowledgement](#acknowledgement)
+
+## Environment Setup
+We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.
+```bash
+docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
+```
+If Docker is not available, you can follow stage `-2` in `run.sh` to install the dependencies locally.
+
+## Data Preparation
+
+`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
+
+```jsonc
+{
+  "text": "An example sentence to be synthesized."
+}
+```
+You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
+
+Stage `0` converts raw JSONL files into the parquet format expected by veRL:
+
+```bash
+bash run.sh 0 0
+```
+Create two JSONL files—`train.jsonl` and `test.jsonl`.  
+The script will then generate two Parquet files:
+
+```
+data/parquet_aishell3/train.parquet
+data/parquet_aishell3/test.parquet
+```
+
+Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
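+
+For reference, each JSONL record is turned into a veRL sample roughly as sketched below (mirroring `process_fn` in `prepare_data.py` from this commit; the `data_source` value shown here is illustrative, the script uses the input file names):
+
+```python
+import json
+
+record = {"text": "An example sentence to be synthesized."}
+
+sample = {
+    "data_source": "train.jsonl_test.jsonl",          # the script uses the input file names as the tag
+    "prompt": [
+        {"role": "user", "content": record["text"]},   # text to synthesize
+        {"role": "assistant", "content": ""},          # left empty; the LLM fills in speech tokens
+    ],
+    "ability": "text-to-speech",
+    "reward_model": {"style": "rule", "ground_truth": record["text"]},
+    "extra_info": {"split": "train", "index": 0, "text": record["text"]},
+}
+print(json.dumps(sample, ensure_ascii=False, indent=2))
+```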
+
+
+## Reward Function & ASR Server
+
+To compute rewards, we run a lightweight server that:
+
+1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
+2. Transcribes the waveform with **SenseVoice** ASR.
+3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1.
+
+Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
+
+```bash
+bash run.sh 1 1
+# Triton server listens on ports 8000/8001/8002
+```
+
+The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
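+
+As a rough illustration of step 3, the score can be thought of as one minus the pinyin-level error rate, clipped to `[0, 1]`. The sketch below is an assumption for clarity only (using `pypinyin` and `jiwer` from `requirements.txt`); the authoritative logic lives in `token2wav_asr_server.py`:
+
+```python
+# Hedged sketch: map an ASR transcript and the ground-truth text to a reward in [0, 1].
+from pypinyin import lazy_pinyin
+import jiwer
+
+def pinyin_reward(asr_transcript: str, ground_truth: str) -> float:
+    hyp = " ".join(lazy_pinyin(asr_transcript))   # pinyin sequence of the ASR output
+    ref = " ".join(lazy_pinyin(ground_truth))     # pinyin sequence of the target text
+    error_rate = jiwer.wer(ref, hyp)              # pinyin-level error rate
+    return max(0.0, min(1.0, 1.0 - error_rate))
+
+print(pinyin_reward("今天天气不错", "今天天气不错"))  # -> 1.0
+```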
+
+## Training
+
+Run stage `2` to start GRPO training:
+
+```bash
+bash run.sh 2 2
+```
+
+Key CLI arguments passed to `verl.trainer.main_ppo`:
+
+* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO.
+* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`
+* `custom_reward_function.path=reward_tts.py` – custom reward function described above.
+
+Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
+> [!TIP]
+> Note: the lm_head bias is disabled during training to make the model compatible with vLLM and Transformers' Qwen model.
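+
+For reference, veRL's `prime` reward manager calls the function named by `custom_reward_function.name` for each generated sample; its interface, as implemented in [`reward_tts.py`](./reward_tts.py), looks like this:
+
+```python
+from __future__ import annotations
+
+def compute_score(
+    data_source: str,                 # dataset tag written by prepare_data.py
+    solution_str: str,                # generated response, e.g. "<|s_12|><|s_345|>..."
+    ground_truth: str,                # target text to synthesize
+    extra_info: dict | None = None,
+) -> float:
+    """Return a reward in [0, 1] for one generated sample."""
+    ...
+```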
+
+## Evaluation
+
+After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
+
+```bash
+bash run.sh 3 3   # merges weights into $llm_path/merged_hf_model
+```
+
+You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
+
+```bash
+bash run.sh 4 4
+```
+
+This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.
+
+> [!TIP]
+> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
+
+## Export Model
+
+To use the RL-trained model with the official CosyVoice repository:
+
+```bash
+bash run.sh 5 5
+```
+
+The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
+> [!TIP]
+>  However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format. 
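+
+If you want to avoid the conversion, you can keep serving the merged Hugging Face checkpoint directly. A minimal sketch (mirroring `infer_dataset.py` in this commit; the checkpoint path is a placeholder):
+
+```python
+# Generate CosyVoice2 speech tokens from the merged HF checkpoint (sketch; path is a placeholder).
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_path = "path/to/merged_hf_model"  # placeholder
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+model = AutoModelForCausalLM.from_pretrained(model_path).eval()
+
+chat = [
+    {"role": "user", "content": "An example sentence to be synthesized."},
+    {"role": "assistant", "content": ""},
+]
+input_ids = tokenizer.apply_chat_template(
+    chat, tokenize=True, return_tensors="pt", continue_final_message=True
+)
+with torch.no_grad():
+    output = model.generate(input_ids, max_new_tokens=512, do_sample=True,
+                            top_p=0.95, top_k=50, temperature=0.8)
+# drop the prompt and the trailing eos token; speech tokens look like <|s_1234|>
+speech_tokens = tokenizer.batch_decode(output[0][input_ids.shape[1]:-1], skip_special_tokens=True)
+print(speech_tokens)
+```
+
+The resulting speech tokens can then be converted to a waveform with the CosyVoice2 token2wav models, as done in `infer_dataset.py`.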
+
+## Results
+
+| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
+|-------|------------------------|------------------------------|---------|
+| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
+| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
+
+## Acknowledgement
+
+This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).

examples/grpo/cosyvoice2/huggingface_to_pretrained.py (+71 -0)

@@ -0,0 +1,71 @@
+
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+python3 huggingface_to_pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
+"""
+from argparse import ArgumentParser
+import torch
+from safetensors import safe_open
+from transformers import AutoTokenizer
+
+
+def get_args():
+    parser = ArgumentParser()
+
+    parser.add_argument(
+        "--hf-cosyvoice2-llm-path",
+        type=str,
+        default=None,
+        help="The RL trained CosyVoice2 model path in HuggingFace format",
+    )
+    parser.add_argument(
+        "--output-path",
+        type=str,
+        default="./llm.pt",
+        help="The path to save the llm.pt",
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path)
+    speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>")
+    cosyvoice2_token_size = 6561 + 3
+    llm_embedding_vocab_size = 2
+
+    hf_tensors = {}
+    with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f:
+        for k in f.keys():
+            if k.startswith("lm_head.bias"):
+                # RL trained model disable bias for lm_head
+                continue
+            new_k = "llm.model." + k
+            hf_tensors[new_k] = f.get_tensor(k)
+            if k.startswith("lm_head"):
+                hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
+                hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0])
+            if k.startswith("model.embed_tokens"):
+                hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
+                hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size]
+
+        # use tie_word_embeddings=True
+        hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936]
+        hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
+
+    torch.save(hf_tensors, args.output_path)

examples/grpo/cosyvoice2/infer_dataset.py (+397 -0)

@@ -0,0 +1,397 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Example Usage
+dataset=zero_shot_zh
+output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
+
+token2wav_path=/workspace/CosyVoice2-0.5B
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+torchrun --nproc_per_node=8 \
+    infer_dataset.py \
+    --output-dir $output_dir \
+    --llm-model-name-or-path $llm_path/merged_hf_model \
+    --token2wav-path $token2wav_path \
+    --split-name ${dataset} || exit 1
+"""
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torchaudio
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from cosyvoice.utils.file_utils import load_wav
+from datasets import load_dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
+from tqdm import tqdm
+import soundfile as sf
+import s3tokenizer
+from functools import partial
+
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+try:
+    torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+    pass
+
+
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
+
+
+def audio_decode_cosyvoice2(
+    audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
+):
+    """
+    Decode CosyVoice2 speech tokens into a 24 kHz waveform with the flow decoder and HiFT vocoder, conditioned on the zero-shot prompt.
+    """
+    model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
+        "empty", prompt_text, prompt_speech_16k, 24000
+    )
+    tts_mel, _ = codec_decoder.model.flow.inference(
+        token=audio_tokens.to(codec_decoder.model.device),
+        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+            codec_decoder.model.device
+        ),
+        prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
+            codec_decoder.model.device
+        ),
+        prompt_token_len=torch.tensor(
+            [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
+            codec_decoder.model.device
+        ),
+        prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
+            codec_decoder.model.device
+        ),
+        embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
+        finalize=True,
+    )
+
+    audio_hat, _ = codec_decoder.model.hift.inference(
+        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+    )
+
+    return audio_hat
+
+
+def extract_speech_ids(speech_tokens_str):
+    """Extract speech IDs from token strings like <|s_23456|>"""
+    speech_ids = []
+    for token_str in speech_tokens_str:
+        if token_str.startswith('<|s_') and token_str.endswith('|>'):
+            num_str = token_str[4:-2]
+            num = int(num_str)
+            speech_ids.append(num)
+        else:
+            print(f"Unexpected token: {token_str}")
+    return speech_ids
+
+
+def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
+    """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
+    speech_id_str = ""
+    for token in cosy2_tokens:
+        speech_id_str += f"<|s_{token}|>"
+    return speech_id_str
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
+    parser.add_argument(
+        "--split-name",
+        type=str,
+        default="wenetspeech4tts",
+        help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
+    )
+    parser.add_argument(
+        "--output-dir", required=True, type=str, help="dir to save result"
+    )
+    parser.add_argument(
+        "--batch-size",
+        default=1,
+        type=int,
+        help="batch size (per-device) for inference",
+    )
+    parser.add_argument(
+        "--num-workers", type=int, default=1, help="workers for dataloader"
+    )
+    parser.add_argument(
+        "--prefetch", type=int, default=5, help="prefetch for dataloader"
+    )
+    parser.add_argument(
+        "--llm-model-name-or-path",
+        required=True,
+        type=str,
+        help="LLM model path (includes both model and tokenizer)",
+    )
+    parser.add_argument(
+        "--token2wav-path",
+        required=True,
+        type=str,
+        help="CosyVoice2 token2wav model path",
+    )
+    parser.add_argument(
+        "--prompt-text",
+        type=str,
+        default=None,
+        help="The prompt text for CosyVoice2",
+    )
+    parser.add_argument(
+        "--prompt-speech-path",
+        type=str,
+        default=None,
+        help="The path to the prompt speech for CosyVoice2",
+    )
+    parser.add_argument(
+        "--top-p",
+        type=float,
+        default=0.95,
+        help="top p for sampling",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.8,
+        help="temperature for sampling",
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=50,
+        help="top k for sampling",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def data_collator(batch, tokenizer, s3_tokenizer):
+    """Simplified data collator for batch_size=1 processing"""
+    target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
+    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
+    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
+    mels, prompt_audio_cosy2tokens_list = [], []
+    for item in batch:
+        prompt_text, target_text = (
+            item["prompt_text"],
+            item["target_text"],
+        )
+        prompt_text_list.append(prompt_text)
+        # Combine prompt and target text
+        full_text = prompt_text + target_text
+
+        # get prompt audio for CosyVoice2 (convert to 16kHz)
+        ref_audio_org, ref_sr = (
+            item["prompt_audio"]["array"],
+            item["prompt_audio"]["sampling_rate"],
+        )
+        ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
+        # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
+        print(ref_audio_org.shape)
+
+        if ref_sr != target_sample_rate:
+            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
+            ref_audio = resampler(ref_audio_org)
+        else:
+            ref_audio = ref_audio_org
+
+        prompt_audio_list.append(ref_audio)
+
+        if "prompt_audio_cosy2_tokens" in item:
+            prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
+            prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
+        else:
+            # convert to float first
+            mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
+
+    if len(mels) > 0:
+        mels, mels_lens = s3tokenizer.padding(mels)
+        codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
+        for i in range(len(codes)):
+            prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
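+    # NOTE: full_text below comes from the last item of the loop above, so this collator
+    # effectively assumes batch_size == 1 (matching the --batch-size default).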
+    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
+        prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
+        # Create chat template for LLM generation
+        chat = [
+            {"role": "user", "content": full_text},
+            {"role": "assistant", "content": prompt_audio_cosy2_id_str}
+        ]
+        if 'system' in tokenizer.chat_template:
+            tokenizer.chat_template = TEMPLATE
+        input_ids = tokenizer.apply_chat_template(
+            chat,
+            tokenize=True,
+            return_tensors='pt',
+            continue_final_message=True
+        )
+        input_ids_list.append(input_ids.squeeze(0))
+
+    # For batch_size=1, no need to pad
+    if len(input_ids_list) == 1:
+        input_ids = input_ids_list[0].unsqueeze(0)
+    else:
+        # Handle batch > 1 if needed
+        max_len = max([len(input_ids) for input_ids in input_ids_list])
+        input_ids_list = [
+            torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
+            for input_ids in input_ids_list
+        ]
+        input_ids = torch.stack(input_ids_list)
+
+    ids = [item["id"] for item in batch]
+
+    return {
+        "input_ids": input_ids,
+        "ids": ids,
+        "prompt_text": prompt_text_list,
+        "prompt_audio_list": prompt_audio_list,
+    }
+
+
+def init_distributed():
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    rank = int(os.environ.get("RANK", 0))
+    print(
+        "Inference on multiple gpus, this gpu {}".format(local_rank)
+        + ", rank {}, world_size {}".format(rank, world_size)
+    )
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group("nccl")
+    return world_size, local_rank, rank
+
+
+def main():
+    args = get_args()
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    assert torch.cuda.is_available()
+    world_size, local_rank, rank = init_distributed()
+    device = torch.device(f"cuda:{local_rank}")
+
+    # Load LLM model and tokenizer directly
+    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
+    model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
+    model.eval()
+    model.to(device)
+
+    cosyvoice_codec = CosyVoice2(
+        args.token2wav_path, load_jit=True, load_trt=True, fp16=True
+    )
+    if args.prompt_speech_path:
+        prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
+    else:
+        prompt_speech_16k = None
+    s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
+    dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
+    dataset = load_dataset(
+        dataset_name,
+        split=args.split_name,
+        trust_remote_code=True,
+    )
+
+    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        sampler=sampler,
+        shuffle=False,
+        num_workers=args.num_workers,
+        prefetch_factor=args.prefetch,
+        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
+    )
+
+    total_steps = len(dataset)
+
+    if rank == 0:
+        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
+
+    for batch in dataloader:
+        with torch.no_grad():
+            input_ids = batch["input_ids"].to(device)
+
+            # Generate speech tokens using LLM
+            outputs = model.generate(
+                input_ids,
+                max_new_tokens=2048,  # Max length for generation
+                do_sample=True,
+                top_p=args.top_p,
+                temperature=args.temperature,
+                top_k=args.top_k,
+            )
+
+            # Process each sample in the batch
+            for i in range(len(batch["ids"])):
+                # Extract generated tokens (excluding input)
+                input_length = input_ids[i].shape[0]
+                generated_ids = outputs[i][input_length:-1]  # Remove last token if needed
+                speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+                # Extract speech IDs from token strings like <|s_23456|>
+                speech_ids = extract_speech_ids(speech_tokens_str)
+
+                if len(speech_ids) == 0:
+                    print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
+                    continue
+
+                # Convert to tensor for CosyVoice2
+                audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
+
+                if args.prompt_text is not None:
+                    current_prompt_text = args.prompt_text
+                    current_prompt_audio = prompt_speech_16k
+                else:
+                    current_prompt_text = batch["prompt_text"][i]
+                    current_prompt_audio = batch["prompt_audio_list"][i]
+
+                if current_prompt_audio is not None:
+                    # Generate audio using CosyVoice2
+                    audio_hat = audio_decode_cosyvoice2(
+                        audio_tokens,
+                        current_prompt_text,
+                        current_prompt_audio,
+                        cosyvoice_codec,
+                    )
+
+                    # Convert to numpy and save
+                    generated_wave = audio_hat.squeeze(0).cpu().numpy()
+                    target_sample_rate = 24000
+
+                    utt = batch["ids"][i]
+                    sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
+
+                    print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
+                else:
+                    print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
+
+        if rank == 0:
+            progress_bar.update(world_size * len(batch["ids"]))
+
+    if rank == 0:
+        progress_bar.close()
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()

examples/grpo/cosyvoice2/prepare_data.py (+86 -0)

@@ -0,0 +1,86 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocess the Text to Speech dataset to parquet format
+"""
+
+import argparse
+import os
+import re
+
+import datasets
+
+from verl.utils.hdfs_io import copy, makedirs
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
+    parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file")
+    parser.add_argument("--local_dir", default=None, required=True)
+    parser.add_argument("--hdfs_dir", default=None)
+
+    args = parser.parse_args()
+
+    # Load datasets from local JSON files
+    train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train']
+    test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train']
+
+    # add a row to each data item that represents a unique id
+    def make_map_fn(split):
+        def process_fn(example, idx):
+            text = example.pop("text")
+
+            # use cosyvoice2 official huggingface compatible checkpoint template
+            question = text
+            answer = ""
+
+            data = {
+                "data_source": f"{args.train_file}_{args.test_file}",  # Use file names as data source
+                "prompt": [
+                    {
+                        "role": "user",
+                        "content": question,
+                    },
+                    {
+                        "role": "assistant",
+                        "content": answer,
+                    },
+                ],
+                "ability": "text-to-speech",
+                "reward_model": {"style": "rule", "ground_truth": text},
+                "extra_info": {
+                    "split": split,
+                    "index": idx,
+                    "text": text,
+                },
+            }
+            return data
+
+        return process_fn
+
+    train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
+    test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
+
+    local_dir = args.local_dir
+    hdfs_dir = args.hdfs_dir
+
+    print(train_dataset)
+    print(test_dataset)
+    train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
+    test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
+
+    if hdfs_dir is not None:
+        makedirs(hdfs_dir)
+
+        copy(src=local_dir, dst=hdfs_dir)

examples/grpo/cosyvoice2/pretrained_to_huggingface.py (+135 -0)

@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Convert the official CosyVoice2-0.5B LLM into a Hugging Face compatible checkpoint:
+  python3 pretrained_to_huggingface.py \
+    --pretrained-cosyvoice2-path /workspace/CosyVoice2-0.5B \
+    --save-path ./transformers_cosyvoice2_llm
+
+Example downstream usage (instruct TTS with the converted checkpoint):
+  python3 infer.py \
+    --token2wav-path /workspace/CosyVoice2-0.5B \
+    --prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
+    --prompt-speech-path ./assets/prompt_audio.wav \
+    --model-path ./transformers_cosyvoice2_llm \
+    --input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
+"""
+from cosyvoice.cli.cosyvoice import CosyVoice2
+import sys
+from argparse import ArgumentParser
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+
+def get_args():
+    parser = ArgumentParser()
+
+    parser.add_argument(
+        "--pretrained-cosyvoice2-path",
+        type=str,
+        default="/workspace/CosyVoice2-0.5B",
+        help="Token2Wav path, default to %(default)r",
+    )
+    parser.add_argument(
+        "--save-path",
+        type=str,
+        default='./transformers_cosyvoice2_llm',
+        help="The path to save the model",
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = get_args()
+    cosy2_model = CosyVoice2(
+        args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False
+    )
+
+    llm = cosy2_model.model.llm.llm.model
+
+    speech_embedding = cosy2_model.model.llm.speech_embedding
+    llm_decoder = cosy2_model.model.llm.llm_decoder
+    llm_embedding = cosy2_model.model.llm.llm_embedding
+
+    tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN")
+    special_tokens = {
+        'eos_token': '<|endoftext|>',
+        'pad_token': '<|endoftext|>',
+        'additional_special_tokens': [
+            '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+            '[breath]', '<strong>', '</strong>', '[noise]',
+            '[laughter]', '[cough]', '[clucking]', '[accent]',
+            '[quick_breath]',
+            "<laughter>", "</laughter>",
+            "[hissing]", "[sigh]", "[vocalized-noise]",
+            "[lipsmack]", "[mn]"
+        ]
+    }
+    tokenizer.add_special_tokens(special_tokens)
+
+    original_tokenizer_vocab_size = len(tokenizer)
+    cosyvoice2_token_size = 6561
+    new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
+        "<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>"
+    ]
+    num_added_tokens = tokenizer.add_tokens(new_tokens)
+
+    llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
+    vocab_size = llm.get_input_embeddings().weight.shape[0]
+
+    feature_size = speech_embedding.embedding_dim
+    new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True)
+
+    with torch.no_grad():
+        # set the weight and bias of the new lm_head to 0
+        new_lm_head.weight.data.zero_()
+        # make bias value -inf
+        new_lm_head.bias.data.fill_(-float('inf'))
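+        # rows outside the speech-token block keep a -inf bias, so generation is
+        # effectively restricted to CosyVoice2 speech tokens and the eos ids copied below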
+        new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
+        new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
+
+    llm.lm_head = new_lm_head
+    input_embeddings = llm.get_input_embeddings()
+
+    with torch.no_grad():
+        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
+
+    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
+    llm.generation_config.eos_token_id = eos_token_ids
+    llm.generation_config.temperature = 1.0
+    llm.generation_config.top_p = 0.8
+    llm.generation_config.top_k = 25
+
+    llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size
+    llm.config.vocab_size = vocab_size
+    llm.config.tie_word_embeddings = False
+    llm.config.use_bias = True
+    llm.to(torch.bfloat16)
+    llm.save_pretrained(args.save_path)
+
+    TEMPLATE = (
+        "{%- for message in messages %}"
+        "{%- if message['role'] == 'user' %}"
+        "{{- '<|sos|>' + message['content'] + '<|task_id|>' }}"
+        "{%- elif message['role'] == 'assistant' %}"
+        "{{- message['content']}}"
+        "{%- endif %}"
+        "{%- endfor %}"
+    )
+    tokenizer.chat_template = TEMPLATE
+    tokenizer.save_pretrained(args.save_path)

examples/grpo/cosyvoice2/requirements.txt (+31 -0)

@@ -0,0 +1,31 @@
+conformer==0.3.2
+diffusers==0.29.0
+gdown==5.1.0
+gradio
+hydra-core==1.3.2
+HyperPyYAML==1.2.2
+inflect==7.3.1
+librosa==0.10.2
+lightning==2.2.4
+matplotlib==3.7.5
+modelscope==1.15.0
+networkx==3.1
+omegaconf==2.3.0
+onnx==1.16.0
+onnxruntime-gpu==1.18.0
+protobuf==4.25
+pydantic==2.7.0
+pyworld==0.3.4
+rich==13.7.1
+soundfile==0.12.1
+tensorboard==2.14.0
+wget==3.2
+WeTextProcessing==1.0.3
+s3tokenizer
+tensorrt
+sherpa_onnx
+jiwer
+zhon
+numpy==1.25.2
+pypinyin
+openai-whisper

examples/grpo/cosyvoice2/reward_tts.py (+233 -0)

@@ -0,0 +1,233 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Reward calculation for CosyVoice2-0.5B.
+"""
+
+from __future__ import annotations
+
+import re
+import json
+import time
+import argparse
+from typing import List
+
+import numpy as np
+import requests
+
+
+REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
+
+
+def _parse_ids(token_str: str) -> List[int]:
+    return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
+
+
+def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
+    """Send token IDs and ground-truth text to the Triton server and get reward."""
+
+    tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1)
+    lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32)
+
+    gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object)
+
+    payload = {
+        "inputs": [
+            {
+                "name": "TOKENS",
+                "shape": list(tokens_arr.shape),
+                "datatype": "INT32",
+                "data": tokens_arr.tolist(),
+            },
+            {
+                "name": "TOKEN_LENS",
+                "shape": list(lens_arr.shape),
+                "datatype": "INT32",
+                "data": lens_arr.tolist(),
+            },
+            {
+                "name": "GT_TEXT",
+                "shape": [1, 1],
+                "datatype": "BYTES",
+                "data": [ground_truth],
+            },
+        ]
+    }
+    rsp = requests.post(
+        REWARD_SERVER_URL,
+        headers={"Content-Type": "application/json"},
+        json=payload,
+        timeout=timeout,
+        verify=False,
+        params={"request_id": "0"},
+    )
+    rsp.raise_for_status()
+    result = rsp.json()
+
+    try:
+        # Reward is returned as the first output
+        return float(result["outputs"][0]["data"][0])
+    except (KeyError, IndexError, TypeError):
+        return 0.0
+
+
+def compute_score(
+    data_source: str,
+    solution_str: str,
+    ground_truth: str,
+    extra_info: dict | None = None,
+    *,
+    debug_dump: bool = False,
+) -> float:
+    """Return reward in [0, 1] using the Triton ASR service.
+
+    The reward is based on the pinyin-level WER between the ASR transcript
+    produced from *solution_str* and the provided *ground_truth* text.
+    """
+
+    # Decode token IDs
+    ids = _parse_ids(solution_str)
+
+    # Query remote server for reward
+    try:
+        reward = _remote_reward(ids, ground_truth)
+    except Exception as e:
+        reward = 0.0
+
+    if debug_dump:
+        print(
+            f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m"
+        )
+
+    return reward
+
+
+# CLI quick test
+if __name__ == "__main__":
+    import sys
+
+    def get_args():
+        """Parse command line arguments."""
+        parser = argparse.ArgumentParser(
+            description="Test TTS CER scoring with data from JSONL file",
+            formatter_class=argparse.ArgumentDefaultsHelpFormatter
+        )
+
+        parser.add_argument(
+            "--input", "-i",
+            type=str,
+            default="data/emilia_zh-cosy-tiny-test.jsonl",
+            help="Path to input JSONL file"
+        )
+
+        parser.add_argument(
+            "--max-samples", "-n",
+            type=int,
+            default=None,
+            help="Maximum number of samples to process (default: all)"
+        )
+
+        parser.add_argument(
+            "--no-interactive",
+            action="store_true",
+            help="Run in non-interactive mode (process all samples without prompts)"
+        )
+
+        parser.add_argument(
+            "--debug",
+            action="store_true",
+            help="Enable debug mode"
+        )
+
+        return parser.parse_args()
+
+    def load_jsonl(file_path: str):
+        """Load data from jsonl file."""
+        data = []
+        with open(file_path, 'r', encoding='utf-8') as f:
+            for line in f:
+                data.append(json.loads(line.strip()))
+        return data
+
+    def code_to_solution_str(code_list: List[int]) -> str:
+        """Convert code list to solution string format."""
+        return ''.join([f"<|s_{code}|>" for code in code_list])
+
+    # Parse command line arguments
+    args = get_args()
+
+    try:
+        # Load data from jsonl file
+        print(f"Loading data from: {args.input}")
+        data_list = load_jsonl(args.input)
+        print(f"Loaded {len(data_list)} samples")
+
+        # Limit samples if specified
+        if args.max_samples is not None:
+            data_list = data_list[:args.max_samples]
+            print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
+
+        # Process each sample
+        begin_time = time.time()
+        for i, sample in enumerate(data_list):
+            print(f"\n--- Sample {i+1}/{len(data_list)} ---")
+            print(f"Index: {sample.get('index', 'unknown')}")
+            print(f"Text: {sample['text']}")
+
+            # Extract required fields
+            code_list = sample['code']
+            ground_truth = sample['text']
+            data_source = sample.get('index', f'sample_{i}')  # Use index as data_source
+
+            # Convert code list to solution string
+            solution_str = code_to_solution_str(code_list)
+            print(f"Solution tokens: {len(code_list)} tokens")
+            if args.debug:
+                print(f"Solution string: {solution_str}")
+            else:
+                print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
+
+            # Call compute_score function
+            try:
+                score = compute_score(
+                    data_source=data_source,
+                    solution_str=solution_str,
+                    ground_truth=ground_truth,
+                    extra_info=None,
+                    debug_dump=args.debug
+                )
+                print(f"Final Score: {score:.4f}")
+            except Exception as e:
+                print(f"Error computing score: {e}")
+
+            # Ask user if they want to continue (for interactive mode)
+            if not args.no_interactive and i < len(data_list) - 1:
+                try:
+                    response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower()
+                    if response == 'q':
+                        break
+                except KeyboardInterrupt:
+                    print("\nStopped by user")
+                    break
+
+        print(f"\nProcessed {min(i+1, len(data_list))} samples")
+        end_time = time.time()
+        print(f"Time taken: {end_time - begin_time} seconds")
+    except FileNotFoundError:
+        print(f"Error: File not found - {args.input}")
+        print("Please check the file path or use --input to specify correct path")
+        print("Run with --help for usage information")
+    except Exception as e:
+        print(f"Error: {e}")

examples/grpo/cosyvoice2/run.sh (+159 -0)

@@ -0,0 +1,159 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=${1:--1}
+stop_stage=${2:-4}
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+export PYTHONPATH=/workspace/CosyVoice
+model_scope_model_path=./CosyVoice2-0.5B
+sft_model_path=./transformers_cosyvoice2_llm
+
+if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then
+  log "stage -2: install dependencies locally if pre-built docker image is not available"
+  conda create -n cosyvoice2 python=3.10 -y
+  conda activate cosyvoice2
+    # install verl
+  git clone https://github.com/yuekaizhang/verl.git -b thread
+  cd verl
+  USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh
+  pip install --no-deps -e .
+  cd -
+  # install requirements
+  pip install -r requirements.txt
+  pip install -U nvidia-pytriton
+  git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e .
+fi
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
+  modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path 
+  python3 pretrained_to_huggingface.py \
+    --pretrained-cosyvoice2-path $model_scope_model_path \
+    --save-path $sft_model_path
+
+  # Or, you could use the following command to download the huggingface compatible checkpoint
+  # huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm
+
+  # Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers.
+fi
+
+data_dir=data/parquet_aishell3
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "stage 0: prepare data into verl format"
+  mkdir -p $data_dir
+  wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl
+  # total 88035 samples
+  head -n 80000 data/aishell-3.jsonl > data/train.jsonl
+  tail -n 100 data/aishell-3.jsonl > data/test.jsonl
+  python prepare_data.py \
+    --train_file data/train.jsonl \
+    --test_file data/test.jsonl \
+    --local_dir $data_dir
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "stage 1: start token2wav asr server for reward function"
+  python3 token2wav_asr_server.py --number-of-devices 8
+fi 
+
+exp_name=official_llm_aishell3_grpo
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "stage 2: grpo train"
+  export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+  export MKL_SERVICE_FORCE_INTEL=TRUE
+  n_gpus_per_node=8
+  micro_batch_size=4
+  train_batch_size=32
+  python3 -m verl.trainer.main_ppo \
+      algorithm.adv_estimator=grpo \
+      data.train_files=$data_dir/train.parquet \
+      data.val_files=$data_dir/test.parquet \
+      data.train_batch_size=$train_batch_size \
+      data.max_prompt_length=1024 \
+      data.max_response_length=512 \
+      data.truncation='error' \
+      actor_rollout_ref.model.use_remove_padding=False \
+      actor_rollout_ref.model.path=$sft_model_path \
+      actor_rollout_ref.actor.optim.lr=1e-6 \
+      actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+      actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
+      actor_rollout_ref.actor.use_kl_loss=False \
+      actor_rollout_ref.model.enable_gradient_checkpointing=True \
+      actor_rollout_ref.actor.fsdp_config.param_offload=False \
+      actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+      actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
+      actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+      actor_rollout_ref.rollout.name=vllm \
+      actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+      actor_rollout_ref.rollout.do_sample=true \
+      actor_rollout_ref.rollout.temperature=0.8 \
+      actor_rollout_ref.rollout.top_p=0.95 \
+      actor_rollout_ref.rollout.top_k=25 \
+      actor_rollout_ref.rollout.n=4 \
+      actor_rollout_ref.rollout.val_kwargs.do_sample=true \
+      actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
+      actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+      actor_rollout_ref.rollout.val_kwargs.top_k=25 \
+      reward_model.reward_manager=prime \
+      custom_reward_function.path=reward_tts.py \
+      custom_reward_function.name=compute_score \
+      trainer.project_name='cosyvoice2_grpo' \
+      trainer.experiment_name=$exp_name \
+      trainer.logger=['console','wandb'] \
+      trainer.n_gpus_per_node=$n_gpus_per_node \
+      trainer.nnodes=1 \
+      trainer.save_freq=100 \
+      trainer.test_freq=100 \
+      trainer.resume_mode='auto' \
+      trainer.total_epochs=1 \
+      trainer.val_before_train=False
+fi
+
+steps=(100 200 300 400 500)
+for step in ${steps[@]}; do
+llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step}
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "stage 3: merge the model"
+  python -m verl.model_merger merge \
+      --backend fsdp \
+      --local_dir $llm_path/actor \
+      --target_dir $llm_path/merged_hf_model || exit 1
+fi 
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "stage 4: Test the model"
+  dataset=zero_shot_zh # from CosyVoice3 test set
+  # dataset=test_zh # from seed_tts test set
+  output_dir=./outputs_${exp_name}_${step}_${dataset}
+
+  token2wav_path=/workspace/CosyVoice2-0.5B
+  model_path=$llm_path/merged_hf_model
+
+  CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+  torchrun --nproc_per_node=8 \
+      infer_dataset.py \
+        --output-dir $output_dir \
+        --llm-model-name-or-path $model_path \
+        --token2wav-path $token2wav_path \
+        --split-name ${dataset} || exit 1
+
+  bash scripts/compute_wer.sh $output_dir ${dataset}
+fi
+done
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "stage 5: Convert the RL trained model to CosyVoice repo format"
+  python3 huggingface_to_pretrained.py \
+    --hf-cosyvoice2-llm-path $llm_path/merged_hf_model \
+    --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
+  # You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt
+  # However, we found that the RL trained model accuracy would slightly drop after this conversion.
+  # Please be careful or use the huggingface format inference code.
+fi

examples/grpo/cosyvoice2/scripts/compute_wer.sh (+33 -0)

@@ -0,0 +1,33 @@
+wav_dir=$1
+wav_files=$(ls $wav_dir/*.wav)
+# if wav_files is empty, then exit
+if [ -z "$wav_files" ]; then
+    exit 1
+fi
+split_name=$2
+model_path=models/sherpa-onnx-paraformer-zh-2023-09-14
+
+if [ ! -d $model_path ]; then
+    pip install sherpa-onnx
+    wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
+    mkdir models
+    tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models
+fi
+
+python3 scripts/offline-decode-files.py  \
+    --tokens=$model_path/tokens.txt \
+    --paraformer=$model_path/model.int8.onnx \
+    --num-threads=2 \
+    --decoding-method=greedy_search \
+    --debug=false \
+    --sample-rate=24000 \
+    --log-dir $wav_dir \
+    --feature-dim=80 \
+    --split-name $split_name \
+    --name sherpa_onnx \
+    $wav_files
+
+# python3 scripts/paraformer-pytriton-client.py  \
+#     --log-dir $wav_dir \
+#     --split-name $split_name \
+#     $wav_files

examples/grpo/cosyvoice2/scripts/offline-decode-files.py (+756 -0)

@@ -0,0 +1,756 @@
+#!/usr/bin/env python3
+#
+# Copyright (c)  2023 by manyeyes
+# Copyright (c)  2023  Xiaomi Corporation
+
+"""
+This file demonstrates how to use sherpa-onnx Python API to transcribe
+file(s) with a non-streaming model.
+
+(1) For paraformer
+
+    ./python-api-examples/offline-decode-files.py  \
+      --tokens=/path/to/tokens.txt \
+      --paraformer=/path/to/paraformer.onnx \
+      --num-threads=2 \
+      --decoding-method=greedy_search \
+      --debug=false \
+      --sample-rate=16000 \
+      --feature-dim=80 \
+      /path/to/0.wav \
+      /path/to/1.wav
+
+(2) For transducer models from icefall
+
+    ./python-api-examples/offline-decode-files.py  \
+      --tokens=/path/to/tokens.txt \
+      --encoder=/path/to/encoder.onnx \
+      --decoder=/path/to/decoder.onnx \
+      --joiner=/path/to/joiner.onnx \
+      --num-threads=2 \
+      --decoding-method=greedy_search \
+      --debug=false \
+      --sample-rate=16000 \
+      --feature-dim=80 \
+      /path/to/0.wav \
+      /path/to/1.wav
+
+(3) For CTC models from NeMo
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
+  --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
+  --num-threads=2 \
+  --decoding-method=greedy_search \
+  --debug=false \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
+  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
+
+(4) For Whisper models
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+  --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+  --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+  --whisper-task=transcribe \
+  --num-threads=1 \
+  ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
+  ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
+  ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
+
+(5) For CTC models from WeNet
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
+  --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
+  ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
+
+(6) For tdnn models of the yesno recipe from icefall
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --sample-rate=8000 \
+  --feature-dim=23 \
+  --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
+  --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
+  ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
+  ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
+  ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/index.html
+to install sherpa-onnx and to download non-streaming pre-trained models
+used in this file.
+"""
+import argparse
+import time
+import wave
+from pathlib import Path
+from typing import List, Tuple, Dict, Iterable, TextIO, Union
+
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+from datasets import load_dataset
+import logging
+from collections import defaultdict
+import kaldialign
+from zhon.hanzi import punctuation
+import string
+punctuation_all = punctuation + string.punctuation
+Pathlike = Union[str, Path]
+
+
+def remove_punctuation(text: str) -> str:
+    for x in punctuation_all:
+        if x == '\'':
+            continue
+        text = text.replace(x, '')
+    return text
+
+
+def store_transcripts(
+    filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
+) -> None:
+    """Save predicted results and reference transcripts to a file.
+
+    Args:
+      filename:
+        File to save the results to.
+      texts:
+        An iterable of tuples. The first element is the cur_id, the second is
+        the reference transcript and the third element is the predicted result.
+        If it is a multi-talker ASR system, the ref and hyp may also be lists of
+        strings.
+    Returns:
+      Return None.
+    """
+    with open(filename, "w", encoding="utf8") as f:
+        for cut_id, ref, hyp in texts:
+            if char_level:
+                ref = list("".join(ref))
+                hyp = list("".join(hyp))
+            print(f"{cut_id}:\tref={ref}", file=f)
+            print(f"{cut_id}:\thyp={hyp}", file=f)
+
+
+def write_error_stats(
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, str]],
+    enable_log: bool = True,
+    compute_CER: bool = False,
+    sclite_mode: bool = False,
+) -> float:
+    """Write statistics based on predicted results and reference transcripts.
+
+    It will write the following to the given file:
+
+        - WER
+        - number of insertions, deletions, substitutions, corrects and total
+          reference words. For example::
+
+              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
+              reference words (2337 correct)
+
+        - The difference between the reference transcript and predicted result.
+          An instance is given below::
+
+            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
+
+          The above example shows that the reference word is `EDISON`,
+          but it is predicted to `ADDISON` (a substitution error).
+
+          Another example is::
+
+            FOR THE FIRST DAY (SIR->*) I THINK
+
+          The reference word `SIR` is missing in the predicted
+          results (a deletion error).
+    Args:
+      results:
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript and the third element is the predicted result.
+      enable_log:
+        If True, also print detailed WER to the console.
+        Otherwise, it is written only to the given file.
+      compute_CER:
+        If True, compute the character error rate (CER) instead of the WER.
+    Returns:
+      Return the total error rate (in percent) as a float.
+    """
+    subs: Dict[Tuple[str, str], int] = defaultdict(int)
+    ins: Dict[str, int] = defaultdict(int)
+    dels: Dict[str, int] = defaultdict(int)
+
+    # `words` stores counts per word, as follows:
+    #   corr, ref_sub, hyp_sub, ins, dels
+    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+    num_corr = 0
+    ERR = "*"
+
+    if compute_CER:
+        for i, res in enumerate(results):
+            cut_id, ref, hyp = res
+            ref = list("".join(ref))
+            hyp = list("".join(hyp))
+            results[i] = (cut_id, ref, hyp)
+
+    for cut_id, ref, hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
+        for ref_word, hyp_word in ali:
+            if ref_word == ERR:
+                ins[hyp_word] += 1
+                words[hyp_word][3] += 1
+            elif hyp_word == ERR:
+                dels[ref_word] += 1
+                words[ref_word][4] += 1
+            elif hyp_word != ref_word:
+                subs[(ref_word, hyp_word)] += 1
+                words[ref_word][1] += 1
+                words[hyp_word][2] += 1
+            else:
+                words[ref_word][0] += 1
+                num_corr += 1
+    ref_len = sum([len(r) for _, r, _ in results])
+    sub_errs = sum(subs.values())
+    ins_errs = sum(ins.values())
+    del_errs = sum(dels.values())
+    tot_errs = sub_errs + ins_errs + del_errs
+    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
+    if enable_log:
+        logging.info(
+            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+            f"{del_errs} del, {sub_errs} sub ]"
+        )
+
+    print(f"%WER = {tot_err_rate}", file=f)
+    print(
+        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+        f"{sub_errs} substitutions, over {ref_len} reference "
+        f"words ({num_corr} correct)",
+        file=f,
+    )
+    print(
+        "Search below for sections starting with PER-UTT DETAILS:, "
+        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+        file=f,
+    )
+
+    print("", file=f)
+    print("PER-UTT DETAILS: corr or (ref->hyp)  ", file=f)
+    for cut_id, ref, hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR)
+        combine_successive_errors = True
+        if combine_successive_errors:
+            ali = [[[x], [y]] for x, y in ali]
+            for i in range(len(ali) - 1):
+                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+                    ali[i] = [[], []]
+            ali = [
+                [
+                    list(filter(lambda a: a != ERR, x)),
+                    list(filter(lambda a: a != ERR, y)),
+                ]
+                for x, y in ali
+            ]
+            ali = list(filter(lambda x: x != [[], []], ali))
+            ali = [
+                [
+                    ERR if x == [] else " ".join(x),
+                    ERR if y == [] else " ".join(y),
+                ]
+                for x, y in ali
+            ]
+
+        print(
+            f"{cut_id}:\t"
+            + " ".join(
+                (
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    for ref_word, hyp_word in ali
+                )
+            ),
+            file=f,
+        )
+
+    print("", file=f)
+    print("SUBSTITUTIONS: count ref -> hyp", file=f)
+
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+        print(f"{count}   {ref} -> {hyp}", file=f)
+
+    print("", file=f)
+    print("DELETIONS: count ref", file=f)
+    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+        print(f"{count}   {ref}", file=f)
+
+    print("", file=f)
+    print("INSERTIONS: count hyp", file=f)
+    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+        print(f"{count}   {hyp}", file=f)
+
+    print("", file=f)
+    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
+    for _, word, counts in sorted(
+        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+    ):
+        (corr, ref_sub, hyp_sub, ins, dels) = counts
+        tot_errs = ref_sub + hyp_sub + ins + dels
+        ref_count = corr + ref_sub + dels
+        hyp_count = corr + hyp_sub + ins
+
+        print(f"{word}   {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
+    return float(tot_err_rate)
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--hotwords-file",
+        type=str,
+        default="",
+        help="""
+        The file containing hotwords, one words/phrases per line, like
+        HELLO WORLD
+        你好世界
+        """,
+    )
+
+    parser.add_argument(
+        "--hotwords-score",
+        type=float,
+        default=1.5,
+        help="""
+        The hotword score of each token for biasing word/phrase. Used only if
+        --hotwords-file is given.
+        """,
+    )
+
+    parser.add_argument(
+        "--modeling-unit",
+        type=str,
+        default="",
+        help="""
+        The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
+        Used only when hotwords-file is given.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-vocab",
+        type=str,
+        default="",
+        help="""
+        The path to the bpe vocabulary, the bpe vocabulary is generated by
+        sentencepiece, you can also export the bpe vocabulary through a bpe model
+        by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
+        and modeling-unit is bpe or cjkchar+bpe.
+        """,
+    )
+
+    parser.add_argument(
+        "--encoder",
+        default="",
+        type=str,
+        help="Path to the encoder model",
+    )
+
+    parser.add_argument(
+        "--decoder",
+        default="",
+        type=str,
+        help="Path to the decoder model",
+    )
+
+    parser.add_argument(
+        "--joiner",
+        default="",
+        type=str,
+        help="Path to the joiner model",
+    )
+
+    parser.add_argument(
+        "--paraformer",
+        default="",
+        type=str,
+        help="Path to the model.onnx from Paraformer",
+    )
+
+    parser.add_argument(
+        "--nemo-ctc",
+        default="",
+        type=str,
+        help="Path to the model.onnx from NeMo CTC",
+    )
+
+    parser.add_argument(
+        "--wenet-ctc",
+        default="",
+        type=str,
+        help="Path to the model.onnx from WeNet CTC",
+    )
+
+    parser.add_argument(
+        "--tdnn-model",
+        default="",
+        type=str,
+        help="Path to the model.onnx for the tdnn model of the yesno recipe",
+    )
+
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--whisper-encoder",
+        default="",
+        type=str,
+        help="Path to whisper encoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-decoder",
+        default="",
+        type=str,
+        help="Path to whisper decoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-language",
+        default="",
+        type=str,
+        help="""It specifies the spoken language in the input audio file.
+        Example values: en, fr, de, zh, jp.
+        Available languages for multilingual models can be found at
+        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+        If not specified, we infer the language from the input audio file.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-task",
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        type=str,
+        help="""For multilingual models, if you specify translate, the output
+        will be in English.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-tail-paddings",
+        default=-1,
+        type=int,
+        help="""Number of tail padding frames.
+        We have removed the 30-second constraint from whisper, so you need to
+        choose the amount of tail padding frames by yourself.
+        Use -1 to use a default value for tail padding.
+        """,
+    )
+
+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="""
+        The penalty applied on blank symbol during decoding.
+        Note: It is a positive value that would be applied to logits like
+        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+        [batch_size, vocab] and blank id is 0).
+        """,
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
+    )
+    parser.add_argument(
+        "--debug",
+        type=bool,
+        default=False,
+        help="True to show debug messages",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="""Sample rate of the feature extractor. Must match the one
+        expected  by the model. Note: The input sound files can have a
+        different sample rate from this argument.""",
+    )
+
+    parser.add_argument(
+        "--feature-dim",
+        type=int,
+        default=80,
+        help="Feature dimension. Must match the one expected by the model",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to decode. Each file must be of WAVE"
+        "format with a single channel, and each sample has 16-bit, "
+        "i.e., int16_t. "
+        "The sample rate of the file can be arbitrary and does not need to "
+        "be 16 kHz",
+    )
+
+    parser.add_argument(
+        "--name",
+        type=str,
+        default="",
+        help="The directory containing the input sound files to decode",
+    )
+
+    parser.add_argument(
+        "--log-dir",
+        type=str,
+        default="",
+        help="The directory containing the input sound files to decode",
+    )
+
+    parser.add_argument(
+        "--label",
+        type=str,
+        default=None,
+        help="wav_base_name label",
+    )
+
+    # Dataset related arguments for loading labels when label file is not provided
+    parser.add_argument(
+        "--dataset-name",
+        type=str,
+        default="yuekai/seed_tts_cosy2",
+        help="Huggingface dataset name for loading labels",
+    )
+
+    parser.add_argument(
+        "--split-name",
+        type=str,
+        default="wenetspeech4tts",
+        help="Dataset split name for loading labels",
+    )
+
+    return parser.parse_args()
+
+
+def assert_file_exists(filename: str):
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    )
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel; the samples are
+        read as 32-bit floating point PCM and the sample rate can be arbitrary.
+
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples,
+         which are normalized to the range [-1, 1].
+       - Sample rate of the wave file.
+    """
+
+    samples, sample_rate = sf.read(wave_filename, dtype="float32")
+    assert (
+        samples.ndim == 1
+    ), f"Expected single channel, but got {samples.ndim} channels."
+
+    samples_float32 = samples.astype(np.float32)
+
+    return samples_float32, sample_rate
+
+
+def normalize_text_alimeeting(text: str) -> str:
+    """
+    Text normalization similar to M2MeT challenge baseline.
+    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
+    """
+    import re
+    text = text.replace('\u00A0', '')  # test_hard
+    text = text.replace(" ", "")
+    text = text.replace("<sil>", "")
+    text = text.replace("<%>", "")
+    text = text.replace("<->", "")
+    text = text.replace("<$>", "")
+    text = text.replace("<#>", "")
+    text = text.replace("<_>", "")
+    text = text.replace("<space>", "")
+    text = text.replace("`", "")
+    text = text.replace("&", "")
+    text = text.replace(",", "")
+    if re.search("[a-zA-Z]", text):
+        text = text.upper()
+    text = text.replace("A", "A")
+    text = text.replace("a", "A")
+    text = text.replace("b", "B")
+    text = text.replace("c", "C")
+    text = text.replace("k", "K")
+    text = text.replace("t", "T")
+    text = text.replace(",", "")
+    text = text.replace("丶", "")
+    text = text.replace("。", "")
+    text = text.replace("、", "")
+    text = text.replace("?", "")
+    text = remove_punctuation(text)
+    return text
+
+
+def main():
+    args = get_args()
+    assert_file_exists(args.tokens)
+    assert args.num_threads > 0, args.num_threads
+
+    assert len(args.nemo_ctc) == 0, args.nemo_ctc
+    assert len(args.wenet_ctc) == 0, args.wenet_ctc
+    assert len(args.whisper_encoder) == 0, args.whisper_encoder
+    assert len(args.whisper_decoder) == 0, args.whisper_decoder
+    assert len(args.tdnn_model) == 0, args.tdnn_model
+
+    assert_file_exists(args.paraformer)
+
+    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
+        paraformer=args.paraformer,
+        tokens=args.tokens,
+        num_threads=args.num_threads,
+        sample_rate=args.sample_rate,
+        feature_dim=args.feature_dim,
+        decoding_method=args.decoding_method,
+        debug=args.debug,
+    )
+
+    print("Started!")
+    start_time = time.time()
+
+    streams, results = [], []
+    total_duration = 0
+
+    for i, wave_filename in enumerate(args.sound_files):
+        assert_file_exists(wave_filename)
+        samples, sample_rate = read_wave(wave_filename)
+        duration = len(samples) / sample_rate
+        total_duration += duration
+        s = recognizer.create_stream()
+        s.accept_waveform(sample_rate, samples)
+
+        streams.append(s)
+        if i % 10 == 0:
+            recognizer.decode_streams(streams)
+            results += [s.result.text for s in streams]
+            streams = []
+            print(f"Processed {i} files")
+
+    # process the last batch
+    if streams:
+        recognizer.decode_streams(streams)
+        results += [s.result.text for s in streams]
+    end_time = time.time()
+    print("Done!")
+
+    results_dict = {}
+    for wave_filename, result in zip(args.sound_files, results):
+        print(f"{wave_filename}\n{result}")
+        print("-" * 10)
+        wave_basename = Path(wave_filename).stem
+        results_dict[wave_basename] = result
+
+    elapsed_seconds = end_time - start_time
+    rtf = elapsed_seconds / total_duration
+    print(f"num_threads: {args.num_threads}")
+    print(f"decoding_method: {args.decoding_method}")
+    print(f"Wave duration: {total_duration:.3f} s")
+    print(f"Elapsed time: {elapsed_seconds:.3f} s")
+    print(
+        f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
+    )
+
+    # Load labels either from file or from dataset
+    labels_dict = {}
+
+    if args.label:
+        # Load labels from file (original functionality)
+        print(f"Loading labels from file: {args.label}")
+        with open(args.label, "r") as f:
+            for line in f:
+                # fields = line.strip().split(" ")
+                # fields = [item for item in fields if item]
+                # assert len(fields) == 4
+                # prompt_text, prompt_audio, text, audio_path = fields
+
+                fields = line.strip().split("|")
+                fields = [item for item in fields if item]
+                assert len(fields) == 4
+                audio_path, prompt_text, prompt_audio, text = fields
+                labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
+    else:
+        # Load labels from dataset (new functionality)
+        print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}")
+        if 'zero' in args.split_name:
+            dataset_name = "yuekai/CV3-Eval"
+        else:
+            dataset_name = "yuekai/seed_tts_cosy2"
+        dataset = load_dataset(
+            dataset_name,
+            split=args.split_name,
+            trust_remote_code=True,
+        )
+
+        for item in dataset:
+            audio_id = item["id"]
+            labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
+
+        print(f"Loaded {len(labels_dict)} labels from dataset")
+
+    # Perform evaluation if labels are available
+    if labels_dict:
+
+        final_results = []
+        for key, value in results_dict.items():
+            if key in labels_dict:
+                final_results.append((key, labels_dict[key], value))
+            else:
+                print(f"Warning: No label found for {key}, skipping...")
+
+        if final_results:
+            store_transcripts(
+                filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
+            )
+            with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
+                write_error_stats(f, "test-set", final_results, enable_log=True)
+
+            with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
+                print(f.readline())  # WER
+                print(f.readline())  # Detailed errors
+        else:
+            print("No matching labels found for evaluation")
+    else:
+        print("No labels available for evaluation")
+
+
+if __name__ == "__main__":
+    main()

+ 346 - 0
examples/grpo/cosyvoice2/token2wav_asr_server.py

@@ -0,0 +1,346 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Pytriton server for token2wav conversion and ASR"""
+
+import argparse
+import logging
+import random
+import sys
+from typing import List
+
+import numpy as np
+import torch
+from scipy.signal import resample
+from datasets import load_dataset
+from jiwer import wer
+from pypinyin import lazy_pinyin, Style
+from tn.chinese.normalizer import Normalizer as ZhNormalizer
+from pytriton.decorators import batch
+from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
+from pytriton.triton import Triton, TritonConfig
+
+# Make the bundled Matcha-TTS importable before loading CosyVoice
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from omnisense.models import OmniSenseVoiceSmall
+
+# Chinese text normalizer (cached globally)
+zh_tn_model = ZhNormalizer(
+    cache_dir="./cache",
+    remove_erhua=False,
+    remove_interjections=False,
+    remove_puncts=True,
+    overwrite_cache=True,
+)
+
+logger = logging.getLogger("token2wav_asr_server")
+
+
+class _ASR_Server:
+    """Wraps a single OmniSenseVoiceSmall model instance for Triton."""
+
+    def __init__(self, device_id: int):
+        self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
+
+    @batch
+    def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
+        """
+        WAV: padded waveforms, WAV_LENS: number of valid samples per waveform.
+        LANGUAGE and TEXT_NORM are accepted for backward compatibility but are not used.
+        See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
+        """
+        logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
+        wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
+
+        results = self._model.transcribe_single_batch(
+            wavs,
+            language="zh",
+            textnorm="woitn",
+        )
+        texts = [result.text for result in results]
+        transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
+        return {"TRANSCRIPTS": transcripts}
+
+
+def audio_decode_cosyvoice2(
+    audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
+):
+    """
+    Generate a waveform from speech tokens with the CosyVoice2 flow and HiFT decoders, conditioned on the prompt text and prompt speech.
+    """
+    model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
+        "empty", prompt_text, prompt_speech_16k, 24000
+    )
+    tts_mel, _ = codec_decoder.model.flow.inference(
+        token=audio_tokens.to(codec_decoder.model.device),
+        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
+            codec_decoder.model.device
+        ),
+        prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
+            codec_decoder.model.device
+        ),
+        prompt_token_len=torch.tensor(
+            [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
+        ).to(codec_decoder.model.device),
+        prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
+            codec_decoder.model.device
+        ),
+        prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
+            codec_decoder.model.device
+        ),
+        embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
+        finalize=True,
+    )
+
+    audio_hat, _ = codec_decoder.model.hift.inference(
+        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+    )
+
+    return audio_hat
+
+
+def get_random_prompt_from_dataset(dataset):
+    """
+    Get random prompt text and speech from the pre-loaded dataset.
+    Returns (prompt_text, prompt_speech_16k)
+    """
+    random_idx = random.randint(0, len(dataset) - 1)
+    sample = dataset[random_idx]
+
+    # Extract audio data
+    audio_data = sample["audio"]
+    audio_array = audio_data["array"]
+    sample_rate = audio_data["sampling_rate"]
+
+    # Convert audio to 16kHz if needed
+    if sample_rate != 16000:
+        num_samples = int(len(audio_array) * (16000 / sample_rate))
+        audio_array = resample(audio_array, num_samples)
+
+    # Convert to torch tensor
+    prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
+    prompt_text = sample["text"]
+    # remove space in prompt_text
+    prompt_text = prompt_text.replace(" ", "")
+    return prompt_text, prompt_speech_16k
+
+
+class _Token2Wav_ASR:
+    """Wraps a single OmniSenseVoiceSmall model instance for Triton."""
+
+    def __init__(self, device_id: int):
+        self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
+        self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
+
+        # Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model
+        # CosyVoice2 internally uses generic "cuda" device, so we first switch the
+        # current CUDA context to the desired card before the object is created.
+        # Afterwards, all parameters loaded with the generic "cuda" device will
+        # reside on this GPU.  We keep the selected id in `self.device_id` and
+        # will set the context again for every forward call to avoid race
+        # conditions when several instances are used in the same process.
+
+        self.device_id = device_id
+
+        # Construct the TTS codec decoder under the correct CUDA device context
+        with torch.cuda.device(self.device_id):
+            self.codec_decoder = CosyVoice2(
+                "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
+            )
+
+    @batch
+    def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
+        """
+        TOKENS: padded speech-token ids, TOKEN_LENS: number of valid tokens per request.
+        GT_TEXT: ground-truth transcripts used to compute the ASR-based rewards.
+        See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
+        """
+        # Ensure the default CUDA device is set correctly for this invocation
+        torch.cuda.set_device(self.device_id)
+
+        if self.device_id == 0:
+            print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
+
+        tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
+
+        # Decode ground-truth text strings (BYTES → str)
+        if GT_TEXT.ndim == 2:
+            gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
+        else:
+            gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
+
+        wavs = []
+        for tokens in tokens_list:
+            prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
+            audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
+            audio_hat = audio_decode_cosyvoice2(
+                audio_tokens,
+                prompt_text,
+                prompt_speech_16k,
+                self.codec_decoder,
+            )
+            # resample from 24 kHz to 16 kHz for the ASR model
+            audio_hat = audio_hat.squeeze(0).float().cpu()
+            audio_hat = audio_hat.numpy()
+            num_samples = int(len(audio_hat) * (16000 / 24000))
+            audio_hat = resample(audio_hat, num_samples)
+            wavs.append(audio_hat)
+
+        results = self.asr_model.transcribe_single_batch(
+            wavs,
+            language="zh",
+            textnorm="woitn",
+        )
+        texts = [result.text for result in results]
+
+        # ---------------- Reward computation ----------------
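+        # For each sample: normalize both texts, convert them to tone-marked pinyin,
+        # compute the syllable-level error rate with jiwer, and map it to a reward in
+        # [0, 1] via 1 - tanh(3 * error_rate), so a perfect match scores 1.0 and the
+        # reward decays smoothly towards 0 as the error rate grows.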
+        rewards = []
+        for gt_text, hyp_text in zip(gt_texts, texts):
+            gt_norm = zh_tn_model.normalize(gt_text).lower()
+            hyp_norm = zh_tn_model.normalize(hyp_text).lower()
+
+            gt_pinyin = lazy_pinyin(
+                gt_norm,
+                style=Style.TONE3,
+                tone_sandhi=True,
+                neutral_tone_with_five=True,
+            )
+            hyp_pinyin = lazy_pinyin(
+                hyp_norm,
+                style=Style.TONE3,
+                tone_sandhi=True,
+                neutral_tone_with_five=True,
+            )
+
+            c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
+            reward_val = 1.0 - np.tanh(3.0 * c)
+            reward_val = max(0.0, min(1.0, reward_val))
+            rewards.append(reward_val)
+            print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
+
+        transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
+        rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
+
+        return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
+
+
+def _infer_function_factory(device_ids: List[int], model_name: str):
+    """Creates a list of inference functions, one for each requested device ID."""
+    infer_funcs = []
+    for device_id in device_ids:
+        if model_name == "sensevoice":
+            infer_funcs.append(_ASR_Server(device_id=device_id))
+        else:
+            infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
+    return infer_funcs
+
+
+def main():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--max-batch-size",
+        type=int,
+        default=32,
+        help="Batch size of request.",
+        required=False,
+    )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--number-of-instances-per-device",
+        type=int,
+        default=1,
+        help="Number of model instances to load.",
+        required=False,
+    )
+    parser.add_argument(
+        "--number-of-devices",
+        type=int,
+        default=8,
+        help="Number of devices to use.",
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="token2wav_asr",
+        choices=["token2wav_asr", "sensevoice"],
+        help="Model name.",
+    )
+
+    args = parser.parse_args()
+
+    log_level = logging.DEBUG if args.verbose else logging.INFO
+    logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
+
+    triton_config = TritonConfig(
+        http_port=8000,
+        grpc_port=8001,
+        metrics_port=8002,
+    )
+
+    device_ids = [i for i in range(args.number_of_devices)]
+    device_ids = device_ids * args.number_of_instances_per_device
+
+    with Triton(config=triton_config) as triton:
+        logger.info("Loading SenseVoice model on device ids: %s", device_ids)
+        if args.model_name == "sensevoice":
+            triton.bind(
+                model_name="sensevoice",
+                infer_func=_infer_function_factory(device_ids, args.model_name),
+                inputs=[
+                    Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
+                    Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
+                    Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
+                    Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
+                ],
+                outputs=[
+                    Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
+                ],
+                config=ModelConfig(
+                    max_batch_size=args.max_batch_size,
+                    batcher=DynamicBatcher(max_queue_delay_microseconds=10000),  # 10ms
+                ),
+                strict=True,
+            )
+        else:
+            triton.bind(
+                model_name="token2wav_asr",
+                infer_func=_infer_function_factory(device_ids, args.model_name),
+                inputs=[
+                    Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
+                    Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
+                    Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
+                ],
+                outputs=[
+                    Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
+                    Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
+                ],
+                config=ModelConfig(
+                    max_batch_size=args.max_batch_size,
+                    batcher=DynamicBatcher(max_queue_delay_microseconds=10000),  # 10ms
+                ),
+                strict=True,
+            )
+        logger.info("Serving inference")
+        triton.serve()
+
+
+if __name__ == "__main__":
+    main()

+ 1 - 1
runtime/python/Dockerfile

@@ -9,5 +9,5 @@ RUN apt-get -y install git unzip git-lfs g++
 RUN git lfs install
 RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
 # here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed
-RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
 RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto

+ 92 - 37
runtime/triton_trtllm/README.md

@@ -1,15 +1,17 @@
-## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server
+## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM
 
-Thanks to the contribution from NVIDIA Yuekai Zhang.
+Contributed by Yuekai Zhang (NVIDIA).
 
 ### Quick Start
+
 Launch the service directly with Docker Compose:
 ```sh
 docker compose up
 ```
 
 ### Build the Docker Image
-Build the image from scratch:
+
+To build the image from scratch:
 ```sh
 docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
 ```
@@ -21,71 +23,124 @@ docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_di
 ```
 
 ### Understanding `run.sh`
+
 The `run.sh` script orchestrates the entire workflow through numbered stages.
 
-Run a subset of stages with:
+You can run a subset of stages with:
 ```sh
 bash run.sh <start_stage> <stop_stage> [service_type]
 ```
-- `<start_stage>` – stage to start from (0-5).
-- `<stop_stage>`  – stage to stop after (0-5).
-
-Stages:
-- **Stage 0** – Download the cosyvoice-2 0.5B model from HuggingFace.
-- **Stage 1** – Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
-- **Stage 2** – Create the Triton model repository and configure the model files (adjusts depending on whether `Decoupled=True/False` will be used later).
-- **Stage 3** – Launch the Triton Inference Server.
-- **Stage 4** – Run the single-utterance HTTP client.
-- **Stage 5** – Run the gRPC benchmark client.
-
-### Export Models to TensorRT-LLM and Launch the Server
+- `<start_stage>`: The stage to start from (0-5).
+- `<stop_stage>`: The stage to stop after (0-5).
+
+**Stages:**
+
+- **Stage 0**: Downloads the `cosyvoice-2 0.5B` model from HuggingFace.
+- **Stage 1**: Converts the HuggingFace checkpoint to the TensorRT-LLM format and builds the TensorRT engines.
+- **Stage 2**: Creates the Triton model repository and configures the model files. The configuration is adjusted based on whether `Decoupled=True` (streaming) or `Decoupled=False` (offline) will be used.
+- **Stage 3**: Launches the Triton Inference Server.
+- **Stage 4**: Runs the single-utterance HTTP client for testing.
+- **Stage 5**: Runs the gRPC benchmark client.
+- **Stage 6**: Runs the offline inference benchmark test.
+
+### Export Models and Launch Server
+
 Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
 ```sh
-# Runs stages 0, 1, 2, and 3
+# This command runs stages 0, 1, 2, and 3
 bash run.sh 0 3
 ```
-*Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.*
+> [!TIP]
+> Both streaming and offline (non-streaming) TTS modes are supported. For streaming TTS, set `Decoupled=True`. For offline TTS, set `Decoupled=False`. You need to rerun stage 2 if you switch between modes.
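+
+For example, to switch an already-exported setup from offline to streaming serving (a sketch, assuming `Decoupled` is configured in `run.sh`):
+```sh
+# Edit Decoupled=True in run.sh, then rebuild the model repository and relaunch the server
+bash run.sh 2 3
+```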
 
 ### Single-Utterance HTTP Client
-Send a single HTTP inference request:
+
+Sends a single HTTP inference request. This is intended for testing the offline TTS mode (`Decoupled=False`):
 ```sh
 bash run.sh 4 4
 ```
 
-### Benchmark with a Dataset
-Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument.
+### Benchmark with Client-Server Mode
+
+To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
 ```sh
-bash run.sh 5 5
+bash run.sh 5 5 # [streaming|offline]
 
-# You can also customise parameters such as num_task and dataset split directly:
+# You can also customize parameters such as the number of tasks and the dataset split:
 # python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
 ```
 > [!TIP]
-> Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Triton’s decoupled mode so that responses are returned as soon as they are ready.
+> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
+
+### Benchmark with Offline Inference Mode
+
+For the offline inference benchmark, run the following commands:
+```sh
+# install FlashCosyVoice for token2wav batching
+# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
+# cd /workspace/FlashCosyVoice
+# pip install -e .
+# cd -
+# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
+
+bash run.sh 6 6
+
+# You can also switch to the HuggingFace backend by setting backend=hf
+```
+
 
 ### Benchmark Results
-Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio):
+The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
 
-| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
-|------|------|-------------|------------------|------------------|-----|
-| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
-| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
-| Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
-| Decoupled=True  | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 |
-| Decoupled=True  | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 |
-| Decoupled=True  | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 |
+**Client-Server Mode: Streaming TTS (First Chunk Latency)**
+| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
+|---|---|---|---|---|
+| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
+| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
+| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 |
+| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
+| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
+| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 |
+
+> If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
 
+**Client-Server Mode: Offline TTS (Full Sentence Latency)**
+| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
+|---|---|---|---|---|---|
+| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
+| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
+| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
+
+**Offline Inference Mode: HuggingFace LLM vs. TensorRT-LLM**
+| Backend | Batch Size | LLM Time (s) | Total Time (s) | RTF |
+|---------|------------|--------------|----------------|-----|
+| HF | 1 | 39.26 |  44.31 | 0.2494 |
+| HF | 2 | 30.54 | 35.62 | 0.2064 |
+| HF | 4 | 18.63 |  23.90 | 0.1421 |
+| HF | 8 | 11.22 | 16.45 | 0.0947 |
+| HF | 16 | 8.42 | 13.78 | 0.0821 |
+| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
+| TRTLLM | 2 | 7.64 |12.65 | 0.0739 |
+| TRTLLM | 4 | 4.89 |  9.38 | 0.0539 |
+| TRTLLM | 8 | 2.92 |  7.23 | 0.0418 |
+| TRTLLM | 16 | 2.01 |  6.63 | 0.0386 |
 ### OpenAI-Compatible Server
-To launch an OpenAI-compatible service, run:
+
+To launch an OpenAI-compatible API service, run the following commands:
 ```sh
 git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
+cd Triton-OpenAI-Speech
 pip install -r requirements.txt
-# After the Triton service is up, start the FastAPI bridge:
+
+# After the Triton service is running, start the FastAPI bridge:
 python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
-# Test with curl
+
+# Test the service with curl:
 bash test/test_cosyvoice.sh
 ```
+> [!NOTE]
+> Currently, only the offline TTS mode is compatible with the OpenAI-compatible server.
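+
+As a quick manual check once the bridge is running, a request along the following lines can be sent. The route and JSON fields are assumptions based on the standard OpenAI speech API; `test/test_cosyvoice.sh` contains the exact format expected by `tts_server.py`:
+```sh
+# Hypothetical request; adjust the route, fields and voice name to match test/test_cosyvoice.sh
+curl http://localhost:10086/v1/audio/speech \
+  -H "Content-Type: application/json" \
+  -d '{"model": "cosyvoice2", "input": "你好，欢迎使用语音合成服务。", "voice": "<reference audio name>"}' \
+  --output output.wav
+```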
 
 ### Acknowledgements
-This section originates from the NVIDIA CISI project. We also provide other multimodal resources—see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details.
+
+This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).
 

+ 55 - 33
runtime/triton_trtllm/client_grpc.py

@@ -257,7 +257,13 @@ def get_args():
         default=0.1,
         help="Chunk overlap duration for streaming reconstruction (in seconds)."
     )
-    # --- End Added arguments ---
+
+    parser.add_argument(
+        "--use-spk2info-cache",
+        type=bool,
+        default=False,
+        help="Use spk2info cache for reference audio.",
+    )
 
     return parser.parse_args()
 
@@ -283,7 +289,8 @@ def prepare_request_input_output(
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None  # Optional padding for offline mode
+    padding_duration: int = None,  # Optional padding for offline mode
+    use_spk2info_cache: bool = False
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -330,7 +337,8 @@ def prepare_request_input_output(
     inputs[3].set_data_from_numpy(input_data_numpy)
 
     outputs = [protocol_client.InferRequestedOutput("waveform")]
-
+    if use_spk2info_cache:
+        inputs = inputs[-1:]
     return inputs, outputs
 
 
@@ -395,38 +403,45 @@ def run_sync_streaming_inference(
     # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
     actual_duration = 0
     if audios:
-        cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
-        fade_out = np.linspace(1, 0, cross_fade_samples)
-        fade_in = np.linspace(0, 1, cross_fade_samples)
-        reconstructed_audio = None
-
-        # Simplified reconstruction based on client_grpc_streaming.py
-        if not audios:
-            print("Warning: No audio chunks received.")
-            reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
-        elif len(audios) == 1:
-            reconstructed_audio = audios[0]
+        # Only spark_tts model uses cross-fade
+        if model_name == "spark_tts":
+            cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
+            fade_out = np.linspace(1, 0, cross_fade_samples)
+            fade_in = np.linspace(0, 1, cross_fade_samples)
+            reconstructed_audio = None
+
+            # Simplified reconstruction based on client_grpc_streaming.py
+            if not audios:
+                print("Warning: No audio chunks received.")
+                reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
+            elif len(audios) == 1:
+                reconstructed_audio = audios[0]
+            else:
+                reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
+                for i in range(1, len(audios)):
+                    # Cross-fade section
+                    cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
+                                           audios[i - 1][-cross_fade_samples:] * fade_out)
+                    # Middle section of the current chunk
+                    middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
+                    # Concatenate
+                    reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
+                # Add the last part of the final chunk
+                reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
+
+            if reconstructed_audio is not None and reconstructed_audio.size > 0:
+                actual_duration = len(reconstructed_audio) / save_sample_rate
+                # Save reconstructed audio
+                sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
+            else:
+                print("Warning: No audio chunks received or reconstructed.")
+                actual_duration = 0  # Set duration to 0 if no audio
         else:
-            reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
-            for i in range(1, len(audios)):
-                # Cross-fade section
-                cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
-                                       audios[i - 1][-cross_fade_samples:] * fade_out)
-                # Middle section of the current chunk
-                middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-                # Concatenate
-                reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
-            # Add the last part of the final chunk
-            reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
-
-        if reconstructed_audio is not None and reconstructed_audio.size > 0:
+            reconstructed_audio = np.concatenate(audios)
+            print(f"reconstructed_audio: {reconstructed_audio.shape}")
             actual_duration = len(reconstructed_audio) / save_sample_rate
             # Save reconstructed audio
-            os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
-        else:
-            print("Warning: No audio chunks received or reconstructed.")
-            actual_duration = 0  # Set duration to 0 if no audio
 
     else:
         print("Warning: No audio chunks received.")
@@ -446,6 +461,7 @@ async def send_streaming(
     save_sample_rate: int = 16000,
     chunk_overlap_duration: float = 0.1,
     padding_duration: int = None,
+    use_spk2info_cache: bool = False,
 ):
     total_duration = 0.0
     latency_data = []
@@ -471,7 +487,8 @@ async def send_streaming(
                     reference_text,
                     target_text,
                     sample_rate,
-                    padding_duration=padding_duration
+                    padding_duration=padding_duration,
+                    use_spk2info_cache=use_spk2info_cache
                 )
                 request_id = str(uuid.uuid4())
                 user_data = UserData()
@@ -527,6 +544,7 @@ async def send(
     padding_duration: int = None,
     audio_save_dir: str = "./",
     save_sample_rate: int = 16000,
+    use_spk2info_cache: bool = False,
 ):
     total_duration = 0.0
     latency_data = []
@@ -545,7 +563,8 @@ async def send(
             reference_text,
             target_text,
             sample_rate,
-            padding_duration=padding_duration
+            padding_duration=padding_duration,
+            use_spk2info_cache=use_spk2info_cache
         )
         sequence_id = 100000000 + i + task_id * 10
         start = time.time()
@@ -667,6 +686,7 @@ async def main():
     manifest_item_list = split_data(manifest_item_list, num_tasks)
 
     os.makedirs(args.log_dir, exist_ok=True)
+
     tasks = []
     start_time = time.time()
     for i in range(num_tasks):
@@ -683,6 +703,7 @@ async def main():
                     audio_save_dir=args.log_dir,
                     padding_duration=1,
                     save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
+                    use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
         elif args.mode == "streaming":
@@ -698,6 +719,7 @@ async def main():
                     padding_duration=10,
                     save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                     chunk_overlap_duration=args.chunk_overlap_duration,
+                    use_spk2info_cache=args.use_spk2info_cache,
                 )
             )
         # --- End Task Creation ---

+ 1 - 1
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py

@@ -32,7 +32,7 @@ import triton_python_backend_utils as pb_utils
 import os
 import numpy as np
 import s3tokenizer
-
+torch.set_num_threads(1)
 ORIGINAL_VOCAB_SIZE = 151663
 
 

+ 1 - 1
runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt

@@ -20,7 +20,7 @@ dynamic_batching {
 }
 parameters [
   {
-   key: "model_dir", 
+   key: "model_dir",
    value: {string_value:"${model_dir}"}
   }
 ]

+ 166 - 57
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

@@ -28,6 +28,8 @@ import json
 import math
 import os
 import re
+import threading
+import time
 from typing import Dict, List, Tuple, Optional, Union
 
 import numpy as np
@@ -35,13 +37,15 @@ import torch
 from torch.utils.dlpack import from_dlpack, to_dlpack
 import triton_python_backend_utils as pb_utils
 from transformers import AutoTokenizer
-import torchaudio.compliance.kaldi as kaldi
+
 import torchaudio
-import onnxruntime
 
 
 from matcha.utils.audio import mel_spectrogram
 
+ORIGINAL_VOCAB_SIZE = 151663
+torch.set_num_threads(1)
+
 
 class TritonPythonModel:
     """Triton Python model for Spark TTS.
@@ -62,6 +66,8 @@ class TritonPythonModel:
         parameters = self.model_config['parameters']
         model_params = {k: v["string_value"] for k, v in parameters.items()}
         self.logger.log_info(f"model_params:{model_params}")
+        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
+        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
 
         # Initialize tokenizer
         llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
@@ -72,11 +78,15 @@ class TritonPythonModel:
         self.device = torch.device("cuda")
         self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
 
-        campplus_model = f'{model_params["model_dir"]}/campplus.onnx'
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
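+        # Speech-token streaming schedule: CosyVoice2 emits 25 speech tokens per second
+        # of audio; token_hop_len sets how many tokens each streaming chunk advances and
+        # flow_pre_lookahead_len how many look-ahead tokens the flow model receives.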
+        self.token_frame_rate = 25
+        self.flow_pre_lookahead_len = 3
+        self.token_hop_len = 15
+
+        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        self.default_spk_info = spk_info["001"]
 
     def forward_llm(self, input_ids):
         """
@@ -105,7 +115,7 @@ class TritonPythonModel:
         """
         # convert input_ids to numpy, with shape [1, sequence_length]
         input_ids = input_ids.cpu().numpy()
-        max_tokens = 1024
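+        # 750 speech tokens ≈ 30 s of audio at 25 tokens/s; bounds the longest single generation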
+        max_tokens = 750
         input_dict = {
             "request_output_len": np.array([[max_tokens]], dtype=np.int32),
             "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
@@ -114,6 +124,8 @@ class TritonPythonModel:
             "runtime_top_p": np.array([[0.95]], dtype=np.float32),
             "runtime_top_k": np.array([[50]], dtype=np.int32),
             "temperature": np.array([[0.8]], dtype=np.float32),
+            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
+            "random_seed": np.array([[42]], dtype=np.uint64),
             "input_ids": input_ids,
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
         }
@@ -188,12 +200,40 @@ class TritonPythonModel:
 
         return prompt_speech_tokens
 
+    def forward_speaker_embedding(self, wav):
+        """Forward pass through the speaker embedding component.
+
+        Args:
+            wav: Input waveform tensor
+
+        Returns:
+            Prompt speaker embedding tensor
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='speaker_embedding',
+            requested_output_names=['prompt_spk_embedding'],
+            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
+        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
+
+        return prompt_spk_embedding
+
     def forward_token2wav(
             self,
-            prompt_speech_tokens: torch.Tensor,
-            prompt_speech_feat: torch.Tensor,
-            prompt_spk_embedding: torch.Tensor,
-            target_speech_tokens: torch.Tensor) -> torch.Tensor:
+            target_speech_tokens: torch.Tensor,
+            request_id: str,
+            prompt_speech_tokens: torch.Tensor = None,
+            prompt_speech_feat: torch.Tensor = None,
+            prompt_spk_embedding: torch.Tensor = None,
+            token_offset: int = None,
+            finalize: bool = None) -> torch.Tensor:
         """Forward pass through the vocoder component.
 
         Args:
@@ -205,16 +245,30 @@ class TritonPythonModel:
         Returns:
             Generated waveform tensor
         """
-        prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
-        prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
-        prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
 
+        inputs_tensor = [target_speech_tokens_tensor]
+
+        if token_offset is not None:
+            assert finalize is not None
+            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
+            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+            inputs_tensor.append(token_offset_tensor)
+            inputs_tensor.append(finalize_tensor)
+
+        if prompt_spk_embedding is not None:
+            assert prompt_speech_feat is not None
+            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
+            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
+            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
             requested_output_names=['waveform'],
-            inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
+            inputs=inputs_tensor,
+            request_id=request_id,
         )
 
         inference_response = inference_request.exec()
@@ -235,17 +289,6 @@ class TritonPythonModel:
         input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
         return input_ids
 
-    def _extract_spk_embedding(self, speech):
-        feat = kaldi.fbank(speech,
-                           num_mel_bins=80,
-                           dither=0,
-                           sample_frequency=16000)
-        feat = feat - feat.mean(dim=0, keepdim=True)
-        embedding = self.campplus_session.run(None,
-                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-        embedding = torch.tensor([embedding]).to(self.device).half()
-        return embedding
-
     def _extract_speech_feat(self, speech):
         speech_feat = mel_spectrogram(
             speech,
@@ -263,6 +306,14 @@ class TritonPythonModel:
         speech_feat = speech_feat.unsqueeze(dim=0)
         return speech_feat
 
+    def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
+        for generated_ids in generated_ids_iter:
+            generated_ids = generated_ids.tolist()
+            if len(generated_ids) == 0:
+                break
+            semantic_token_ids_arr.extend(generated_ids)
+        llm_is_done_flag[0] = True
+
     def execute(self, requests):
         """Execute inference on the batched requests.
 
@@ -275,25 +326,33 @@ class TritonPythonModel:
         responses = []
 
         for request in requests:
+            request_id = request.request_id()
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
-            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
 
             # Process reference audio through audio tokenizer
-
-            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-
-            wav_tensor = wav.as_numpy()
-            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-            speech_feat = self._extract_speech_feat(prompt_speech_resample)
-            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-
-            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
-            reference_text = reference_text[0][0].decode('utf-8')
+            if wav is not None:
+                wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+                prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+                prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+                wav_tensor = wav.as_numpy()
+                wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+                prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+                speech_feat = self._extract_speech_feat(prompt_speech_resample)
+                token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+                prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+                prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+                reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+                reference_text = reference_text[0][0].decode('utf-8')
+                prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+            else:
+                # using pre-cached reference text
+                reference_text = self.default_spk_info["prompt_text"]
+                prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
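+                # cached speech tokens are stored in flow-vocab space, so they are shifted by
+                # ORIGINAL_VOCAB_SIZE here to align with the LLM's extended vocabulary;
+                # token2wav shifts them back before flow inference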
+                prompt_speech_feat = None
+                prompt_spk_embedding = None
 
             target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
             target_text = target_text[0][0].decode('utf-8')
@@ -310,22 +369,73 @@ class TritonPythonModel:
 
             if self.decoupled:
                 response_sender = request.get_response_sender()
-                request_id = request.request_id()
-                generated_ids = []
-                for generated_id in generated_ids_iter:
-                    # convert the numpy array into a int32 tensor
-                    generated_id = generated_id.tolist()
-                    if len(generated_id) > 0:
-                        assert len(generated_id) == 1, "Generated ID is not a single integer"
-                        generated_ids.append(generated_id[0])
-                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
 
-                # Prepare response
-                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+                semantic_token_ids_arr = []
+                llm_is_done_flag = [False]
+
+                llm_thread = threading.Thread(
+                    target=self._llm_gen_thread,
+                    args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
+                )
+
+                llm_thread.start()
+
+                token_offset, chunk_index = 0, 0
+                start_time = time.time()
+                this_token_hop_len = self.token_hop_len
+
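+                # streaming loop: wait until at least token_hop_len + flow_pre_lookahead_len new
+                # tokens are buffered, synthesize one chunk, then advance token_offset by the hop
+                # so the next token2wav call only returns newly generated audio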
+                while True:
+                    pending_num = len(semantic_token_ids_arr) - token_offset
+
+                    if llm_is_done_flag[0]:
+                        break
+
+                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
+                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
+                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
+
+                        sub_tts_speech = self.forward_token2wav(
+                            this_tts_speech_token, request_id, prompt_speech_tokens,
+                            prompt_speech_feat, prompt_spk_embedding, token_offset, False
+                        )
+
+                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                        token_offset += this_token_hop_len
+                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
+
+                        if self.dynamic_chunk_strategy == "exponential":
+                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "time_based":
+                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
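+                            # heuristic: compare the duration of audio synthesized so far with the
+                            # elapsed wall-clock time; the further ahead of real time we are, the
+                            # larger the next hop, trading latency for fewer flow/hift calls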
+                            cost_time = time.time() - start_time
+                            duration = token_offset / self.token_frame_rate
+                            if chunk_index > 0 and cost_time > 0:
+                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
+                                if avg_chunk_processing_time > 0:
+                                    multiples = (duration - cost_time) / avg_chunk_processing_time
+                                    self.logger.log_info(f"multiples: {multiples}")
+                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
+                                    if multiples > 4:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
+                                    elif multiples > 2:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
+                                    else:
+                                        this_token_hop_len = self.token_hop_len
+                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
+                        chunk_index += 1
+                    else:
+                        time.sleep(0.02)
+
+                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
+                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
+
+                llm_thread.join()
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                 self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
@@ -334,8 +444,7 @@ class TritonPythonModel:
                 if generated_ids is None or len(generated_ids) == 0:
                     raise pb_utils.TritonModelException("Generated IDs is None or empty")
 
-                prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
-                audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
+                audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
 
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))

+ 5 - 2
runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt

@@ -23,11 +23,11 @@ model_transaction_policy {
 }
 parameters [
   {
-   key: "llm_tokenizer_dir", 
+   key: "llm_tokenizer_dir",
    value: {string_value:"${llm_tokenizer_dir}"}
   },
   {
-   key: "model_dir", 
+   key: "model_dir",
    value: {string_value:"${model_dir}"}
   }
 ]
@@ -37,16 +37,19 @@ input [
     name: "reference_wav"
     data_type: TYPE_FP32
     dims: [-1]
+    optional: true
   },
   {
     name: "reference_wav_len"
     data_type: TYPE_INT32
     dims: [1]
+    optional: true
   },
   {
     name: "reference_text"
     data_type: TYPE_STRING
     dims: [1]
+    optional: true
   },
   {
     name: "target_text"

+ 153 - 0
runtime/triton_trtllm/model_repo/speaker_embedding/1/model.py

@@ -0,0 +1,153 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import json
+import torch
+from torch.utils.dlpack import to_dlpack
+
+import triton_python_backend_utils as pb_utils
+
+import os
+import numpy as np
+import torchaudio.compliance.kaldi as kaldi
+from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.utils.common import TrtContextWrapper
+import onnxruntime
+
+
+class TritonPythonModel:
+    """Triton Python model for audio tokenization.
+
+    This model takes reference audio input and extracts semantic tokens
+    using s3tokenizer.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+
+        self.device = torch.device("cuda")
+
+        model_dir = model_params["model_dir"]
+        gpu = "l20"
+        enable_trt = True
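+        # the TRT engine is built from campplus.onnx on first use if the .trt file is missing;
+        # the "l20" tag only determines the cached engine filename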
+        if enable_trt:
+            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
+                              f'{model_dir}/campplus.onnx',
+                              1,
+                              False)
+        else:
+            campplus_model = f'{model_dir}/campplus.onnx'
+            option = onnxruntime.SessionOptions()
+            option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+            option.intra_op_num_threads = 1
+            self.spk_model = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+
+    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
+        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
+            trt_kwargs = self.get_spk_trt_kwargs()
+            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
+        import tensorrt as trt
+        with open(spk_model, 'rb') as f:
+            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
+        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_spk_trt_kwargs(self):
+        min_shape = [(1, 4, 80)]
+        opt_shape = [(1, 500, 80)]
+        max_shape = [(1, 3000, 80)]
+        input_names = ["input"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def _extract_spk_embedding(self, speech):
+        feat = kaldi.fbank(speech,
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
+        spk_feat = feat - feat.mean(dim=0, keepdim=True)
+
+        if isinstance(self.spk_model, onnxruntime.InferenceSession):
+            embedding = self.spk_model.run(
+                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+            )[0].flatten().tolist()
+            embedding = torch.tensor([embedding]).to(self.device)
+        else:
+            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
+            # NOTE: we need to synchronize when switching streams
+            with torch.cuda.device(self.device):
+                torch.cuda.current_stream().synchronize()
+                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
+                batch_size = spk_feat.size(0)
+
+                with stream:
+                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
+                    embedding = torch.empty((batch_size, 192), device=spk_feat.device)
+
+                    data_ptrs = [spk_feat.contiguous().data_ptr(),
+                                 embedding.contiguous().data_ptr()]
+                    for i, ptr in enumerate(data_ptrs):
+                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), ptr)
+                    # run trt engine
+                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                    torch.cuda.current_stream().synchronize()
+                self.spk_model.release_estimator(spk_model, stream)
+
+        return embedding.half()
+
+    def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing speaker embeddings
+        """
+        responses = []
+        # Process each request in batch
+        for request in requests:
+            # Extract input tensors
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
+            wav_array = torch.from_numpy(wav_array).to(self.device)
+
+            embedding = self._extract_spk_embedding(wav_array)
+
+            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
+                "prompt_spk_embedding", to_dlpack(embedding))
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[prompt_spk_embedding_tensor])
+
+            responses.append(inference_response)
+
+        return responses

+ 48 - 0
runtime/triton_trtllm/model_repo/speaker_embedding/config.pbtxt

@@ -0,0 +1,48 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "speaker_embedding"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+parameters [
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+  }
+]
+output [
+  {
+    name: "prompt_spk_embedding"
+    data_type: TYPE_FP16
+    dims: [-1]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]

+ 116 - 33
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -32,22 +32,27 @@ from typing import List, Dict
 
 import torch
 from torch.utils.dlpack import to_dlpack
+from torch.nn import functional as F
 
 import triton_python_backend_utils as pb_utils
 
 from hyperpyyaml import load_hyperpyyaml
+from cosyvoice.utils.common import fade_in_out
 from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
 from cosyvoice.utils.common import TrtContextWrapper
+from collections import defaultdict
+import numpy as np
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
 ORIGINAL_VOCAB_SIZE = 151663
+torch.set_num_threads(1)
 
 
 class CosyVoice2:
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
 
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -57,7 +62,7 @@ class CosyVoice2:
             raise ValueError('{} not found!'.format(hyper_yaml_path))
         with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
-        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16)
+        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
         self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
@@ -73,14 +78,22 @@ class CosyVoice2Model:
     def __init__(self,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+                 fp16: bool = False,
+                 device: str = 'cuda'):
+        self.device = device
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
         if self.fp16 is True:
             self.flow.half()
 
+        # streaming tts config
+        self.token_hop_len = 25
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        self.hift_cache_dict = defaultdict(lambda: None)
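+        # the last mel_cache_len mel frames and source_cache_len waveform samples of each chunk
+        # are cached per request id; the Hamming window crossfades consecutive chunks to avoid
+        # discontinuities at chunk boundaries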
+
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
@@ -111,6 +124,42 @@ class CosyVoice2Model:
         input_names = ["x", "mask", "mu", "cond"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             streaming=stream,
+                                             finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
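+        # drop the mel frames already returned for earlier chunks
+        # (token_offset tokens correspond to token_offset * token_mel_ratio mel frames)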
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
 
 class TritonPythonModel:
     """Triton Python model for vocoder.
@@ -131,13 +180,19 @@ class TritonPythonModel:
         model_dir = model_params["model_dir"]
 
         # Initialize device and vocoder
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
 
         self.token2wav_model = CosyVoice2(
-            model_dir, load_jit=True, load_trt=True, fp16=True
+            model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
         )
 
+        spk_info_path = os.path.join(model_dir, "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_dir}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        self.default_spk_info = spk_info["001"]
+
         logger.info("Token2Wav initialized successfully")
 
     def execute(self, requests):
@@ -153,38 +208,66 @@ class TritonPythonModel:
         # Process each request in batch
         for request in requests:
             target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens").as_numpy()
-            prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
-            prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
-
             target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
-            prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
-            prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
-            prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
+
+            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
+            if prompt_speech_tokens_tensor is not None:
+                prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
+                prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
+                prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
+                prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
+                prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
+                prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
+                prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
+            else:
+                prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
+                prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
+                prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
 
             # shift the speech tokens according to the original vocab size
-            prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
 
-            tts_mel, _ = self.token2wav_model.model.flow.inference(
-                token=target_speech_tokens,
-                token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
-                    self.device
-                ),
-                prompt_token=prompt_speech_tokens,
-                prompt_token_len=torch.tensor(
-                    [prompt_speech_tokens.shape[1]], dtype=torch.int32
-                ).to(self.device),
-                prompt_feat=prompt_speech_feat,
-                prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
-                embedding=prompt_spk_embedding,
-                streaming=False,
-                finalize=True,
-            )
-
-            audio_hat, _ = self.token2wav_model.model.hift.inference(
-                speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
-            )
+            # token_offset is an optional input used to support both streaming and offline TTS;
+            # it must be omitted for offline requests.
+            token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
+            if token_offset is not None:
+                token_offset = token_offset.as_numpy().item()
+                finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+                stream = not finalize
+                request_id = request.request_id()
+                audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
+                                                                 prompt_token=prompt_speech_tokens,
+                                                                 prompt_feat=prompt_speech_feat,
+                                                                 embedding=prompt_spk_embedding,
+                                                                 token_offset=token_offset,
+                                                                 uuid=request_id,
+                                                                 stream=stream,
+                                                                 finalize=finalize)
+                if finalize:
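+                    # streaming for this request is complete, release its per-request cache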
+                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
+
+            else:
+                tts_mel, _ = self.token2wav_model.model.flow.inference(
+                    token=target_speech_tokens,
+                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
+                        self.device
+                    ),
+                    prompt_token=prompt_speech_tokens,
+                    prompt_token_len=torch.tensor(
+                        [prompt_speech_tokens.shape[1]], dtype=torch.int32
+                    ).to(self.device),
+                    prompt_feat=prompt_speech_feat,
+                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
+                    embedding=prompt_spk_embedding,
+                    streaming=False,
+                    finalize=True,
+                )
+
+                audio_hat, _ = self.token2wav_model.model.hift.inference(
+                    speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+                )
 
             generated_wave = audio_hat.squeeze(0).cpu().numpy()
 

+ 18 - 1
runtime/triton_trtllm/model_repo/token2wav/config.pbtxt

@@ -20,7 +20,7 @@ dynamic_batching {
 }
 parameters [
   {
-   key: "model_dir", 
+   key: "model_dir",
    value: {string_value:"${model_dir}"}
   }
 ]
@@ -35,16 +35,33 @@ input [
     name: "prompt_speech_tokens"
     data_type: TYPE_INT32
     dims: [-1]
+    optional: true
   },
   {
     name: "prompt_speech_feat"
     data_type: TYPE_FP16
     dims: [-1, 80]
+    optional: true
   },
   {
     name: "prompt_spk_embedding"
     data_type: TYPE_FP16
     dims: [-1]
+    optional: true
+  },
+  {
+    name: "token_offset"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "finalize"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
   }
 ]
 output [

+ 563 - 0
runtime/triton_trtllm/offline_inference.py

@@ -0,0 +1,563 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Example Usage
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 offline_inference.py \
+            --output-dir $output_dir \
+            --llm-model-name-or-path $huggingface_model_local_dir \
+            --token2wav-path $model_scope_model_local_dir \
+            --backend $backend \
+            --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+            --engine-dir $trt_engines_dir \
+            --split-name ${dataset} || exit 1
+"""
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torchaudio
+from cosyvoice.utils.file_utils import load_wav
+from datasets import load_dataset
+from transformers import AutoTokenizer
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+import soundfile as sf
+import s3tokenizer
+from functools import partial
+import time
+
+from token2wav import CosyVoice2_Token2Wav
+
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+try:
+    torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+    pass
+
+
+def extract_speech_ids(speech_tokens_str):
+    """Extract speech IDs from token strings like <|s_23456|>"""
+    speech_ids = []
+    for token_str in speech_tokens_str:
+        if token_str.startswith('<|s_') and token_str.endswith('|>'):
+            num_str = token_str[4:-2]
+            num = int(num_str)
+            speech_ids.append(num)
+        else:
+            print(f"Unexpected token: {token_str}")
+    return speech_ids
+
+
+def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
+    """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
+    speech_id_str = ""
+    for token in cosy2_tokens:
+        speech_id_str += f"<|s_{token}|>"
+    return speech_id_str
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
+    parser.add_argument(
+        "--split-name",
+        type=str,
+        default="wenetspeech4tts",
+        help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
+    )
+    parser.add_argument(
+        "--output-dir", required=True, type=str, help="dir to save result"
+    )
+    parser.add_argument(
+        "--batch-size",
+        default=1,
+        type=int,
+        help="batch size (per-device) for inference",
+    )
+    parser.add_argument(
+        "--token2wav-batch-size",
+        default=1,
+        type=int,
+        help="batch size (per-device) for inference",
+    )
+    parser.add_argument(
+        "--num-workers", type=int, default=0, help="workers for dataloader"
+    )
+    parser.add_argument(
+        "--prefetch", type=int, default=None, help="prefetch for dataloader"
+    )
+    parser.add_argument(
+        "--llm-model-name-or-path",
+        required=True,
+        type=str,
+        help="LLM model path (includes both model and tokenizer)",
+    )
+    parser.add_argument(
+        "--token2wav-path",
+        required=True,
+        type=str,
+        help="CosyVoice2 token2wav model path",
+    )
+    parser.add_argument(
+        "--prompt-text",
+        type=str,
+        default=None,
+        help="The prompt text for CosyVoice2",
+    )
+    parser.add_argument(
+        "--prompt-speech-path",
+        type=str,
+        default=None,
+        help="The path to the prompt speech for CosyVoice2",
+    )
+    parser.add_argument(
+        "--top-p",
+        type=float,
+        default=0.95,
+        help="top p for sampling",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.8,
+        help="temperature for sampling",
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=50,
+        help="top k for sampling",
+    )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="hf",
+        choices=["hf", "trtllm", "vllm"],
+        help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
+    )
+    parser.add_argument(
+        "--engine-dir",
+        type=str,
+        default=None,
+        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
+    )
+    parser.add_argument(
+        "--kv-cache-free-gpu-memory-fraction",
+        type=float,
+        default=0.6,
+        help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def data_collator(batch, tokenizer, s3_tokenizer):
+    """Simplified data collator for batch_size=1 processing"""
+    collator_start_time = time.time()
+    total_audio_processing_time = 0
+    total_speech_tokenization_time = 0
+    total_text_tokenization_time = 0
+
+    target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
+    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
+    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
+    prompt_text_after_apply_template_list = []
+    mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
+    for _, item in enumerate(batch):
+        audio_processing_start_time = time.time()
+        prompt_text, target_text = (
+            item["prompt_text"],
+            item["target_text"],
+        )
+        prompt_text_list.append(prompt_text)
+        full_text = prompt_text + target_text
+        full_text_list.append(full_text)
+        # remove unnecessary punctuation for the cosyvoice3 zero_shot_zh dataset
+        puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\'']
+        for p in puncts:
+            if p in full_text:
+                full_text = full_text.replace(p, '')
+                print(f"removed {p} from {full_text}")
+
+        # get prompt audio for CosyVoice2 (convert to 16kHz)
+        ref_audio_org, ref_sr = (
+            item["prompt_audio"]["array"],
+            item["prompt_audio"]["sampling_rate"],
+        )
+        ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
+        print(ref_audio_org.shape)
+
+        if ref_sr != target_sample_rate:
+            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
+            ref_audio = resampler(ref_audio_org)
+        else:
+            ref_audio = ref_audio_org
+
+        prompt_audio_list.append(ref_audio)
+        audio_processing_end_time = time.time()
+        total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
+
+        speech_tokenization_start_time = time.time()
+        if "prompt_audio_cosy2_tokens" in item:
+            prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
+            prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
+        else:
+            mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
+
+    if len(mels) > 0:
+        mels, mels_lens = s3tokenizer.padding(mels)
+        codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
+        for i in range(len(codes)):
+            prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
+    speech_tokenization_end_time = time.time()
+    total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
+
+    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+        text_tokenization_start_time = time.time()
+        prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
+        # Create chat template for LLM generation
+        chat = [
+            {"role": "user", "content": full_text_list[i]},
+            {"role": "assistant", "content": prompt_audio_cosy2_id_str}
+        ]
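+        # the assistant turn is pre-filled with the prompt's speech tokens and continued via
+        # continue_final_message=True below, so the LLM generates the target speech tokens as a
+        # continuation of the prompt audio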
+
+        assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
+
+        input_ids = tokenizer.apply_chat_template(
+            chat,
+            tokenize=True,
+            return_tensors='pt',
+            continue_final_message=True
+        )
+        input_ids_list.append(input_ids.squeeze(0))
+
+        prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
+
+        prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
+        text_tokenization_end_time = time.time()
+        total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
+
+    ids = [item["id"] for item in batch]
+
+    return {
+        "input_ids": input_ids_list,
+        "ids": ids,
+        "prompt_text": prompt_text_list,
+        "prompt_audio_list": prompt_audio_list,
+        "prompt_text_after_apply_template": prompt_text_after_apply_template_list,
+        "audio_processing_time": total_audio_processing_time,
+        "speech_tokenization_time": total_speech_tokenization_time,
+        "text_tokenization_time": total_text_tokenization_time,
+    }
+
+
+def init_distributed():
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    rank = int(os.environ.get("RANK", 0))
+    print(
+        "Inference on multiple gpus, this gpu {}".format(local_rank)
+        + ", rank {}, world_size {}".format(rank, world_size)
+    )
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group("nccl")
+    return world_size, local_rank, rank
+
+
+def main(args):
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    assert torch.cuda.is_available()
+    local_rank, world_size, rank = 0, 1, 0
+    device = torch.device(f"cuda:{local_rank}")
+
+    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
+
+    if args.backend == "hf":
+        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
+        model.eval()
+        model.to(device)
+        runner = None
+    elif args.backend == "trtllm":
+        if args.engine_dir is None:
+            raise ValueError("--engine-dir is required when backend is 'trtllm'")
+
+        runtime_rank = tensorrt_llm.mpi_rank()
+        model = None
+
+        runner_kwargs = dict(
+            engine_dir=args.engine_dir,
+            rank=runtime_rank,
+            max_output_len=2048,
+            enable_context_fmha_fp32_acc=False,
+            max_batch_size=args.batch_size,
+            max_input_len=512,
+            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
+            cuda_graph_mode=False,
+            gather_generation_logits=False,
+        )
+
+        runner = ModelRunnerCpp.from_dir(**runner_kwargs)
+    elif args.backend == "vllm":
+        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
+        runner = None
+    else:
+        raise ValueError(f"Unsupported backend: {args.backend}")
+
+    token2wav_model = CosyVoice2_Token2Wav(
+        model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
+    )
+    if args.prompt_speech_path:
+        prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
+    else:
+        prompt_speech_16k = None
+    s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
+    dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
+    dataset = load_dataset(
+        dataset_name,
+        split=args.split_name,
+        trust_remote_code=True,
+    )
+
+    sampler = None
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        sampler=sampler,
+        shuffle=False,
+        num_workers=args.num_workers,
+        prefetch_factor=args.prefetch,
+        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
+    )
+    for run_idx in range(3):
+        print(f"Benchmark run {run_idx}")
+        total_llm_time = 0
+        total_token2wav_time = 0
+        total_data_load_time = 0
+        total_llm_post_processing_time = 0
+        total_audio_save_time = 0
+        total_audio_processing_time_in_collator = 0
+        total_speech_tokenization_time_in_collator = 0
+        total_text_tokenization_time_in_collator = 0
+        total_audio_samples = 0
+        start_time = time.time()
+        total_steps = len(dataset)
+
+        if rank == 0:
+            progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
+
+        last_batch_end_time = time.time()
+        for batch in dataloader:
+            data_loaded_time = time.time()
+            total_data_load_time += data_loaded_time - last_batch_end_time
+            total_audio_processing_time_in_collator += batch["audio_processing_time"]
+            total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
+            total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
+            with torch.no_grad():
+                llm_start_time = time.time()
+                if args.backend == "hf":
+                    input_ids_list = batch["input_ids"]
+                    if len(input_ids_list) == 1:
+                        input_ids = input_ids_list[0].unsqueeze(0)
+                        attention_mask = torch.ones_like(input_ids)
+                    else:
+                        max_len = max([len(input_ids) for input_ids in input_ids_list])
+                        input_ids_list_new = [
+                            torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
+                            for input_ids in input_ids_list
+                        ]
+                        input_ids = torch.stack(input_ids_list_new)
+                        attention_mask = torch.zeros_like(input_ids)
+                        for i in range(len(input_ids_list)):
+                            attention_mask[i, :len(input_ids_list[i])] = 1
+
+                    input_ids = input_ids.to(device)
+
+                    outputs = model.generate(
+                        input_ids=input_ids.to(device),
+                        attention_mask=attention_mask.to(device),
+                        max_new_tokens=2048,
+                        do_sample=True,
+                        top_p=args.top_p,
+                        temperature=args.temperature,
+                        repetition_penalty=1.1,
+                        top_k=args.top_k,
+                    )
+                    torch.cuda.synchronize()
+                elif args.backend == "trtllm":
+                    batch_input_ids = list(batch["input_ids"])
+                    input_lengths = [x.size(0) for x in batch_input_ids]
+
+                    end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
+                    print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
+                    outputs = runner.generate(
+                        batch_input_ids=batch_input_ids,
+                        max_new_tokens=2048,
+                        end_id=end_id,
+                        pad_id=end_id,
+                        temperature=args.temperature,
+                        top_k=args.top_k,
+                        top_p=args.top_p,
+                        repetition_penalty=1.1,
+                        num_return_sequences=1,
+                        streaming=False,
+                        output_sequence_lengths=True,
+                        output_generation_logits=False,
+                        return_dict=True,
+                        return_all_generated_tokens=False
+                    )
+                    torch.cuda.synchronize()
+                    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
+                    num_output_sents, num_beams, _ = output_ids.size()
+                    assert num_beams == 1
+                    beam = 0
+                    batch_size = len(batch["input_ids"])
+                    num_return_sequences = num_output_sents // batch_size
+                    assert num_return_sequences == 1
+                    outputs = []
+                    for i in range(batch_size * num_return_sequences):
+                        batch_idx = i // num_return_sequences
+                        seq_idx = i % num_return_sequences
+                        output_begin = input_lengths[batch_idx]
+                        output_end = sequence_lengths[i][beam]
+                        outputs_i = output_ids[i][beam][:output_end].tolist()
+                        outputs.append(outputs_i)
+                elif args.backend == "vllm":
+                    input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
+                    sampling_params = SamplingParams(
+                        temperature=args.temperature,
+                        top_p=args.top_p,
+                        top_k=args.top_k,
+                        repetition_penalty=1.1,
+                        max_tokens=2048,
+                    )
+                    outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
+                    print(outputs)
+                    for j, output in enumerate(outputs):
+                        outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
+
+                llm_end_time = time.time()
+                total_llm_time += (llm_end_time - llm_start_time)
+
+                items_for_token_2wav = []
+                for i in range(len(batch["ids"])):
+                    llm_post_processing_start_time = time.time()
+                    input_length = len(batch["input_ids"][i])
+                    generated_ids = outputs[i][input_length:]
+                    speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+                    speech_ids = extract_speech_ids(speech_tokens_str)
+                    print(i, speech_ids)
+                    if len(speech_ids) == 0:
+                        print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
+                        continue
+
+                    if args.prompt_text is not None:
+                        current_prompt_text = args.prompt_text
+                        current_prompt_audio = prompt_speech_16k
+                    else:
+                        current_prompt_text = batch["prompt_text"][i]
+                        current_prompt_audio = batch["prompt_audio_list"][i]
+
+                    llm_post_processing_end_time = time.time()
+                    total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
+                    if current_prompt_audio is not None:
+                        items_for_token_2wav.append({
+                            "speech_ids": speech_ids,
+                            "prompt_audio": current_prompt_audio.squeeze(0),
+                            "id": batch["ids"][i]
+                        })
+                    else:
+                        print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
+
+                for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
+                    t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
+                    if not t2w_batch:
+                        continue
+
+                    t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
+                    t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
+                    t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
+                    t2w_ids = [item["id"] for item in t2w_batch]
+
+                    token2wav_start_time = time.time()
+                    generated_wavs = token2wav_model(
+                        t2w_generated_speech_tokens_list,
+                        t2w_prompt_audios_list,
+                        t2w_prompt_audios_sample_rate,
+                    )
+                    torch.cuda.synchronize()
+                    token2wav_end_time = time.time()
+                    total_token2wav_time += (token2wav_end_time - token2wav_start_time)
+
+                    audio_save_start_time = time.time()
+                    for j, audio_hat in enumerate(generated_wavs):
+                        generated_wave = audio_hat.squeeze().cpu().numpy()
+                        total_audio_samples += len(generated_wave)
+                        target_sample_rate = 24000
+
+                        utt = t2w_ids[j]
+                        sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
+                        print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
+                    audio_save_end_time = time.time()
+                    total_audio_save_time += audio_save_end_time - audio_save_start_time
+
+            if rank == 0:
+                progress_bar.update(world_size * len(batch["ids"]))
+
+            last_batch_end_time = time.time()
+        if rank == 0:
+            progress_bar.close()
+            end_time = time.time()
+            target_sample_rate = 24000
+            total_audio_duration_seconds = total_audio_samples / target_sample_rate
+
+            log_file_path = os.path.join(args.output_dir, "log.txt")
+            with open(log_file_path, 'w') as f:
+                args_dict = vars(args)
+                log_data = {
+                    "args": args_dict,
+                    "data_load_time_seconds": total_data_load_time,
+                    "audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
+                    "speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
+                    "text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
+                    "llm_time_seconds": total_llm_time,
+                    "llm_post_processing_time_seconds": total_llm_post_processing_time,
+                    "token2wav_time_seconds": total_token2wav_time,
+                    "audio_save_time_seconds": total_audio_save_time,
+                    "total_audio_duration_seconds": total_audio_duration_seconds,
+                    "pipeline_time_seconds": end_time - start_time,
+                }
+                print(log_data)
+                f.write(json.dumps(log_data, indent=4))
+            print(f"Metrics logged to {log_file_path}")
+
+
+if __name__ == "__main__":
+    args = get_args()
+    if args.backend == "vllm":
+        from vllm import LLM, SamplingParams
+    elif args.backend == "trtllm":
+        import tensorrt_llm
+        from tensorrt_llm.runtime import ModelRunnerCpp
+    elif args.backend == "hf":
+        from transformers import AutoModelForCausalLM
+    else:
+        raise ValueError(f"Unsupported backend: {args.backend}")
+    main(args)
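
The log.txt written by the loop above breaks the pipeline down into wall-clock components together with the total duration of generated audio, so a real-time factor (RTF) can be derived directly from it. A minimal sketch, assuming log.txt sits in the --output-dir passed to the script and uses the JSON keys emitted above:

``` python
import json

with open("log.txt") as f:
    log = json.load(f)

# RTF = processing time / synthesized audio duration; values below 1.0 mean faster than real time
audio_s = log["total_audio_duration_seconds"]
print(f"pipeline RTF:  {log['pipeline_time_seconds'] / audio_s:.3f}")
print(f"llm RTF:       {log['llm_time_seconds'] / audio_s:.3f}")
print(f"token2wav RTF: {log['token2wav_time_seconds'] / audio_s:.3f}")
```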

+ 46 - 10
runtime/triton_trtllm/run.sh

@@ -15,6 +15,8 @@ trt_engines_dir=./trt_engines_${trt_dtype}
 
 model_repo=./model_repo_cosyvoice2
 
+use_spk2info_cache=False
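+# set to True to read cached prompt speech tokens / feats / speaker embeddings from spk2info.pt
+# instead of deploying the audio_tokenizer and speaker_embedding models (see stages 0 and 2 below)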
+
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
     echo "Cloning CosyVoice"
     git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
@@ -25,8 +27,11 @@ fi
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     echo "Downloading CosyVoice2-0.5B"
+    # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
     huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
     modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
+    # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
+    wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
 fi
 
 
@@ -57,9 +62,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     cosyvoice2_dir="cosyvoice2"
 
     cp -r ./model_repo/${cosyvoice2_dir} $model_repo
-    cp -r ./model_repo/audio_tokenizer $model_repo
     cp -r ./model_repo/tensorrt_llm $model_repo
     cp -r ./model_repo/token2wav $model_repo
+    if [ $use_spk2info_cache == "False" ]; then
+        cp -r ./model_repo/audio_tokenizer $model_repo
+        cp -r ./model_repo/speaker_embedding $model_repo
+    fi
 
     ENGINE_PATH=$trt_engines_dir
     MAX_QUEUE_DELAY_MICROSECONDS=0
@@ -67,13 +75,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     LLM_TOKENIZER_DIR=$huggingface_model_local_dir
     BLS_INSTANCE_NUM=4
     TRITON_MAX_BATCH_SIZE=16
-    DECOUPLED_MODE=False
+    DECOUPLED_MODE=True # True for streaming, False for offline
 
     python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-    python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
     python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
     python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
-
+    if [ $use_spk2info_cache == "False" ]; then
+        python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+        python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    fi
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@@ -82,7 +92,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 fi
 
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-    echo "Single request test http"
+    echo "Single request test http, only work for offline TTS mode"
     python3 client_http.py \
         --reference-audio ./assets/prompt_audio.wav \
         --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
@@ -93,14 +103,40 @@ fi
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     echo "Running benchmark client grpc"
     num_task=4
-    # set mode=streaming, when decoupled=True
-    # set mode=offline, when decoupled=False
-    mode=offline
+
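+    # mode must match DECOUPLED_MODE from stage 2: streaming when True, offline when False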
+    mode=streaming
+    BLS_INSTANCE_NUM=4
+
     python3 client_grpc.py \
         --server-addr localhost \
         --model-name cosyvoice2 \
         --num-tasks $num_task \
         --mode $mode \
+        --use-spk2info-cache $use_spk2info_cache \
         --huggingface-dataset yuekai/seed_tts_cosy2 \
-        --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype}
-fi
+        --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  echo "stage 6: Offline inference benchmark"
+  n_gpus=1
+  datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
+  backend=trtllm # hf, trtllm, vllm
+
+  batch_sizes=(16 8 4 2 1)
+  token2wav_batch_size=1
+  for batch_size in ${batch_sizes[@]}; do
+    for dataset in ${datasets[@]}; do
+      output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
+      CUDA_VISIBLE_DEVICES=0 \
+          python3 offline_inference.py \
+              --output-dir $output_dir \
+              --llm-model-name-or-path $huggingface_model_local_dir \
+              --token2wav-path $model_scope_model_local_dir \
+              --backend $backend \
+              --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+              --engine-dir $trt_engines_dir \
+              --split-name ${dataset} || exit 1
+    done
+  done
+fi

+ 335 - 0
runtime/triton_trtllm/token2wav.py

@@ -0,0 +1,335 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Example Usage
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 token2wav.py --enable-trt || exit 1
+"""
+import torch
+from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
+from flashcosyvoice.modules.hifigan import HiFTGenerator
+from flashcosyvoice.utils.audio import mel_spectrogram
+import torchaudio.compliance.kaldi as kaldi
+import onnxruntime
+import s3tokenizer
+from torch.utils.data import DataLoader
+from datasets import load_dataset
+import torchaudio
+import os
+import logging
+import argparse
+import queue
+import time
+
+
+def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
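+    """Build an (optionally FP16) TensorRT engine from an ONNX model and serialize it to trt_model,
+    using the dynamic-shape profile given by trt_kwargs (input_names plus min/opt/max shapes)."""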
+    import tensorrt as trt
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
+    if fp16:
+        config.set_flag(trt.BuilderFlag.FP16)
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(trt_kwargs['input_names'])):
+        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
+    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
+    logging.info("Succesfully convert onnx to trt...")
+
+
+class TrtContextWrapper:
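+    """Keep a pool of TensorRT execution contexts, each paired with its own CUDA stream,
+    so up to trt_concurrent inferences can share a single deserialized engine."""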
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        self.trt_engine = trt_engine
+        self.device = device
+        for _ in range(trt_concurrent):
+            trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
+            assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reducing trt_concurrent (currently {})'.format(trt_concurrent)
+            self.trt_context_pool.put([trt_context, trt_stream])
+        assert not self.trt_context_pool.empty(), 'no available estimator context'
+
+    def acquire_estimator(self):
+        return self.trt_context_pool.get(), self.trt_engine
+
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])
+
+
+class CosyVoice2_Token2Wav(torch.nn.Module):
+    def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
+        super().__init__()
+        self.device_id = device_id
+        self.device = f"cuda:{device_id}"
+
+        self.flow = CausalMaskedDiffWithXvec()
+        self.flow.half()
+        self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
+        self.flow.to(self.device).eval()
+
+        self.hift = HiFTGenerator()
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
+
+        self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()
+
+        gpu = "l20"
+        if enable_trt:
+            self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
+                          f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                          1,
+                          True)
+            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
+                              f'{model_dir}/campplus.onnx',
+                              1,
+                              False)
+
+    def forward_spk_embedding(self, spk_feat):
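+        # CAM++ speaker encoder: uses the CPU ONNX session by default, or the TensorRT engine
+        # once load_spk_trt() has replaced self.spk_model with a TrtContextWrapper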
+        if isinstance(self.spk_model, onnxruntime.InferenceSession):
+            return self.spk_model.run(
+                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+            )[0].flatten().tolist()
+        else:
+            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
+            # NOTE need to synchronize when switching stream
+            with torch.cuda.device(self.device_id):
+                torch.cuda.current_stream().synchronize()
+                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
+                batch_size = spk_feat.size(0)
+
+                with stream:
+                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
+                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
+
+                    data_ptrs = [spk_feat.contiguous().data_ptr(),
+                                 output_tensor.contiguous().data_ptr()]
+                    for i, ptr in enumerate(data_ptrs):
+                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), ptr)
+                    # run trt engine
+                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                    torch.cuda.current_stream().synchronize()
+                self.spk_model.release_estimator(spk_model, stream)
+
+            return output_tensor.cpu().numpy().flatten().tolist()
+
+    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
+        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
+            trt_kwargs = self.get_spk_trt_kwargs()
+            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
+        import tensorrt as trt
+        with open(spk_model, 'rb') as f:
+            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
+        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_spk_trt_kwargs(self):
+        min_shape = [(1, 4, 80)]
+        opt_shape = [(1, 500, 80)]
+        max_shape = [(1, 3000, 80)]
+        input_names = ["input"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
+            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
+            convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
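+        # NOTE: every shape uses batch*2; the flow decoder estimator appears to stack two passes
+        # per sample (e.g. classifier-free guidance), so a token2wav batch of N reaches the engine as 2N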
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
+        opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
+        max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
+                     (max_batch_size * 2, 80)]
+        input_names = ["x", "mask", "mu", "cond", "t", "spks"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
+        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
+            prompt_speech_mels_list.append(log_mel)
+        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
+        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
+            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
+        )
+        for i in range(len(prompt_speech_tokens)):
+            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
+            prompt_speech_tokens_list.append(speech_tokens_i)
+        return prompt_speech_tokens_list
+
+    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
+        spk_emb_for_flow = []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
+            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
+            spk_emb = self.forward_spk_embedding(spk_feat)
+
+            spk_emb_for_flow.append(spk_emb)
+        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
+        return spk_emb_for_flow
+
+    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
+        prompt_mels_for_flow = []
+        prompt_mels_lens_for_flow = []
+        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
+            assert len(audio.shape) == 1
+            audio = audio.unsqueeze(0)
+            if sample_rate != 24000:
+                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
+            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
+            mel_len = mel.shape[0]
+            prompt_mels_for_flow.append(mel)
+            prompt_mels_lens_for_flow.append(mel_len)
+        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
+        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
+        return prompt_mels_for_flow, prompt_mels_lens_for_flow
+
+    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
+                     prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+        batch_size = prompt_mels_for_flow.shape[0]
+        flow_inputs = []
+        flow_inputs_lens = []
+        for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
+            flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
+            flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
+
+        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
+        flow_inputs_lens = torch.tensor(flow_inputs_lens)
+
+        with torch.amp.autocast(self.device, dtype=torch.float16):
+            generated_mels, generated_mels_lens = self.flow(
+                flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
+                prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device),
+                streaming=False, finalize=True
+            )
+
+        return generated_mels, generated_mels_lens
+
+    def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
+        batch_size = generated_mels.shape[0]
+        generated_wavs = []
+        for i in range(batch_size):
+            mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
+            wav, _ = self.hift(speech_feat=mel)
+            generated_wavs.append(wav)
+        return generated_wavs
+
+    @torch.inference_mode()
+    def forward(
+        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
+    ):
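+        """Vocode pre-generated speech tokens: tokenize the 16 kHz prompt audio, extract its mel and
+        speaker embedding, run the flow decoder, then HiFT, returning one 24 kHz waveform per item."""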
+        # all prompt audios must be sampled at 16000 Hz
+        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
+
+        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
+
+        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
+
+        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
+
+        generated_mels, generated_mels_lens = self.forward_flow(
+            prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+
+        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
+
+        return generated_wavs
+
+
+def collate_fn(batch):
+    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
+    for _, item in enumerate(batch):
+        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
+        audio = torch.from_numpy(item['prompt_audio']['array']).float()
+        prompt_audios_list.append(audio)
+        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
+        ids.append(item['id'])
+
+    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--enable-trt", action="store_true")
+    parser.add_argument("--model-dir", type=str, default="./CosyVoice2-0.5B")
+    parser.add_argument("--batch-size", type=int, default=4)
+    parser.add_argument("--output-dir", type=str, default="generated_wavs")
+    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
+    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
+    # create the output directory if it does not exist
+    os.makedirs(args.output_dir, exist_ok=True)
+    dataset_name = "yuekai/seed_tts_cosy2"
+
+    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
+
+    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
+
+    for _ in range(args.warmup):
+        start_time = time.time()
+
+        for batch in data_loader:
+            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
+
+            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
+
+            for id, wav in zip(ids, generated_wavs):
+                torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
+
+        end_time = time.time()
+        epoch_time = end_time - start_time
+        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")