Browse Source

add cosyvoice2 offline inference

root 5 months ago
parent
commit
cc1991870b

+ 11 - 0
README.md

@@ -246,6 +246,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o
 cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 ```
 
+#### Using Nvidia TensorRT-LLM for deployment
+
+Using TensorRT-LLM to accelerate cosyvoice2 llm could give 4x acceleration comparing with huggingface transformers implementation.
+To quick start:
+
+``` sh
+cd runtime/triton_trtllm
+docker compose up -d
+```
+For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)
+
 ## Discussion & Communication
 
 You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).

+ 34 - 4
runtime/triton_trtllm/README.md

@@ -1,4 +1,4 @@
-## Serving CosyVoice with NVIDIA Triton Inference Server
+## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM
 
 Contributed by Yuekai Zhang (NVIDIA).
 
@@ -41,6 +41,7 @@ bash run.sh <start_stage> <stop_stage> [service_type]
 - **Stage 3**: Launches the Triton Inference Server.
 - **Stage 4**: Runs the single-utterance HTTP client for testing.
 - **Stage 5**: Runs the gRPC benchmark client.
+- **Stage 6**: Runs the offline inference benchmark test.
 
 ### Export Models and Launch Server
 
@@ -59,7 +60,7 @@ Sends a single HTTP inference request. This is intended for testing the offline
 bash run.sh 4 4
 ```
 
-### Benchmark with a Dataset
+### Benchmark with client-server mode
 
 To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
 ```sh
@@ -71,10 +72,26 @@ bash run.sh 5 5 # [streaming|offline]
 > [!TIP]
 > It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
 
+### Benchmark with offline inference mode
+For offline inference mode benchmark, please check the below command:
+```sh
+# install FlashCosyVoice for token2wav batching
+# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
+# cd /workspace/FlashCosyVoice
+# pip install -e . 
+# cd -
+# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
+
+bash run.sh 6 6
+
+# You can also switch to huggingface backend by setting backend=hf
+```
+
+
 ### Benchmark Results
 The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
 
-**Streaming TTS (First Chunk Latency)**
+**Client-Server Mode: Streaming TTS (First Chunk Latency)**
 | Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
 |---|---|---|---|---|
 | Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
@@ -86,13 +103,26 @@ The following results were obtained by decoding on a single L20 GPU with 26 prom
 
 > If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
 
-**Offline TTS (Full Sentence Latency)**
+**Client-Server Mode: Offline TTS (Full Sentence Latency)**
 | Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
 |---|---|---|---|---|---|
 | Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
 | Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
 | Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
 
+**Offline Inference Mode: Hugginface LLM V.S. TensorRT-LLM**
+| Backend | Batch Size | llm_time_seconds  | total_time_seconds | RTF |
+|---------|------------|------------------|-----------------------|--|
+| HF | 1 | 39.26 |  44.31 | 0.2494 |
+| HF | 2 | 30.54 | 35.62 | 0.2064 |
+| HF | 4 | 18.63 |  23.90 | 0.1421 |
+| HF | 8 | 11.22 | 16.45 | 0.0947 | 
+| HF | 16 | 8.42 | 13.78 | 0.0821 |
+| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
+| TRTLLM | 2 | 7.64 |12.65 | 0.0739 |
+| TRTLLM | 4 | 4.89 |  9.38 | 0.0539 |
+| TRTLLM | 8 | 2.92 |  7.23 | 0.0418 |
+| TRTLLM | 16 | 2.01 |  6.63 | 0.0386 |
 ### OpenAI-Compatible Server
 
 To launch an OpenAI-compatible API service, run the following commands:

+ 605 - 0
runtime/triton_trtllm/offline_inference.py

@@ -0,0 +1,605 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Example Usage
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 offline_inference.py \
+            --output-dir $output_dir \
+            --llm-model-name-or-path $huggingface_model_local_dir \
+            --token2wav-path $model_scope_model_local_dir \
+            --backend $backend \
+            --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+            --engine-dir $trt_engines_dir \
+            --split-name ${dataset} || exit 1
+"""
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torchaudio
+from cosyvoice.utils.file_utils import load_wav
+from datasets import load_dataset
+from transformers import AutoTokenizer
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+import soundfile as sf
+import s3tokenizer
+from functools import partial
+import time
+
+from token2wav import CosyVoice2_Token2Wav
+
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
+try:
+    torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+    pass
+
+
+def extract_speech_ids(speech_tokens_str):
+    """Extract speech IDs from token strings like <|s_23456|>"""
+    speech_ids = []
+    for token_str in speech_tokens_str:
+        if token_str.startswith('<|s_') and token_str.endswith('|>'):
+            num_str = token_str[4:-2]
+            num = int(num_str)
+            speech_ids.append(num)
+        else:
+            print(f"Unexpected token: {token_str}")
+    return speech_ids
+
+def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
+    """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
+    speech_id_str = ""
+    for token in cosy2_tokens:
+        speech_id_str += f"<|s_{token}|>"
+    return speech_id_str
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
+    parser.add_argument(
+        "--split-name",
+        type=str,
+        default="wenetspeech4tts",
+        help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
+    )
+    parser.add_argument(
+        "--output-dir", required=True, type=str, help="dir to save result"
+    )
+    parser.add_argument(
+        "--batch-size",
+        default=1,
+        type=int,
+        help="batch size (per-device) for inference",
+    )
+    parser.add_argument(
+        "--token2wav-batch-size",
+        default=1,
+        type=int,
+        help="batch size (per-device) for inference",
+    )
+    parser.add_argument(
+        "--num-workers", type=int, default=0, help="workers for dataloader"
+    )
+    parser.add_argument(
+        "--prefetch", type=int, default=None, help="prefetch for dataloader"
+    )
+    parser.add_argument(
+        "--llm-model-name-or-path",
+        required=True,
+        type=str,
+        help="LLM model path (includes both model and tokenizer)",
+    )
+    parser.add_argument(
+        "--token2wav-path",
+        required=True,
+        type=str,
+        help="CosyVoice2 token2wav model path",
+    )
+    parser.add_argument(
+        "--prompt-text",
+        type=str,
+        default=None,
+        help="The prompt text for CosyVoice2",
+    )
+    parser.add_argument(
+        "--prompt-speech-path",
+        type=str,
+        default=None,
+        help="The path to the prompt speech for CosyVoice2",
+    )
+    parser.add_argument(
+        "--top-p",
+        type=float,
+        default=0.95,
+        help="top p for sampling",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.8,
+        help="temperature for sampling",
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=50,
+        help="top k for sampling",
+    )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="hf",
+        choices=["hf", "trtllm", "vllm"],
+        help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
+    )
+    parser.add_argument(
+        "--engine-dir",
+        type=str,
+        default=None,
+        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
+    )
+    parser.add_argument(
+        "--kv-cache-free-gpu-memory-fraction",
+        type=float,
+        default=0.6,
+        help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
+    )
+    args = parser.parse_args()
+    return args
+
+
+
+def data_collator(batch, tokenizer, s3_tokenizer):
+    """Simplified data collator for batch_size=1 processing"""
+    collator_start_time = time.time()
+    total_audio_processing_time = 0
+    total_speech_tokenization_time = 0
+    total_text_tokenization_time = 0
+
+    target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
+    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
+    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
+    prompt_text_after_apply_template_list = []
+    mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
+    for i, item in enumerate(batch):
+        audio_processing_start_time = time.time()
+        prompt_text, target_text = (
+            item["prompt_text"],
+            item["target_text"],
+        )
+        prompt_text_list.append(prompt_text)
+        full_text = prompt_text + target_text
+        full_text_list.append(full_text)
+        # remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
+        puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\'']
+        for p in puncts:
+            if p in full_text:
+                full_text = full_text.replace(p, '')
+                print(f"removed {p} from {full_text}")
+
+        # get prompt audio for CosyVoice2 (convert to 16kHz)
+        ref_audio_org, ref_sr = (
+            item["prompt_audio"]["array"],
+            item["prompt_audio"]["sampling_rate"],
+        )
+        ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
+        # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
+        print(ref_audio_org.shape)
+
+        if ref_sr != target_sample_rate:
+            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
+            ref_audio = resampler(ref_audio_org)
+        else:
+            ref_audio = ref_audio_org
+
+        prompt_audio_list.append(ref_audio)
+        audio_processing_end_time = time.time()
+        total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
+
+        speech_tokenization_start_time = time.time()
+        if "prompt_audio_cosy2_tokens" in item:
+            prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
+            prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
+        else:
+            # convert to float first
+            mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
+
+    if len(mels) > 0:
+        mels, mels_lens = s3tokenizer.padding(mels)
+        codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
+        for i in range(len(codes)):
+            prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
+    speech_tokenization_end_time = time.time()
+    total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
+
+    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+        text_tokenization_start_time = time.time()
+        prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
+        # Create chat template for LLM generation
+        chat = [
+            {"role": "user", "content": full_text_list[i]},
+            {"role": "assistant", "content": prompt_audio_cosy2_id_str}
+        ]
+
+        assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
+
+        input_ids = tokenizer.apply_chat_template(
+            chat,
+            tokenize=True,
+            return_tensors='pt',
+            continue_final_message=True
+        )
+        input_ids_list.append(input_ids.squeeze(0))
+
+        prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
+
+        prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
+        text_tokenization_end_time = time.time()
+        total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
+
+    ids = [item["id"] for item in batch]
+
+    return {
+        "input_ids": input_ids_list,
+        "ids": ids,
+        "prompt_text": prompt_text_list,
+        "prompt_audio_list": prompt_audio_list,
+        "prompt_text_after_apply_template": prompt_text_after_apply_template_list,
+        "audio_processing_time": total_audio_processing_time,
+        "speech_tokenization_time": total_speech_tokenization_time,
+        "text_tokenization_time": total_text_tokenization_time,
+    }
+
+
+def init_distributed():
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    rank = int(os.environ.get("RANK", 0))
+    print(
+        "Inference on multiple gpus, this gpu {}".format(local_rank)
+        + ", rank {}, world_size {}".format(rank, world_size)
+    )
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group("nccl")
+    return world_size, local_rank, rank
+
+
+def main(args):
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    assert torch.cuda.is_available()
+    # world_size, local_rank, rank = init_distributed()
+    local_rank, world_size, rank = 0, 1, 0
+    device = torch.device(f"cuda:{local_rank}")
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
+
+    # model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
+    # Initialize backend based on argument
+    if args.backend == "hf":
+        # Load HuggingFace model
+        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
+        model.eval()
+        model.to(device)
+        runner = None
+    elif args.backend == "trtllm":
+        # Validate engine_dir is provided
+        if args.engine_dir is None:
+            raise ValueError("--engine-dir is required when backend is 'trtllm'")
+        # import tensorrt_llm
+        #from tensorrt_llm.runtime import ModelRunnerCpp
+
+        # Initialize TensorRT-LLM runner
+        runtime_rank = tensorrt_llm.mpi_rank()
+        model = None
+
+        # Prepare input for runner initialization
+        runner_kwargs = dict(
+            engine_dir=args.engine_dir,
+            rank=runtime_rank,
+            max_output_len=2048,
+            enable_context_fmha_fp32_acc=False,
+            max_batch_size=args.batch_size,
+            max_input_len=512,
+            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
+            cuda_graph_mode=False,
+            gather_generation_logits=False,
+        )
+
+        runner = ModelRunnerCpp.from_dir(**runner_kwargs)
+    elif args.backend == "vllm":
+        # from vllm import LLM, SamplingParams
+        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
+        runner = None
+    else:
+        raise ValueError(f"Unsupported backend: {args.backend}")
+
+    token2wav_model = CosyVoice2_Token2Wav(
+        model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
+    )
+    if args.prompt_speech_path:
+        prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
+    else:
+        prompt_speech_16k = None
+    s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
+    dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
+    dataset = load_dataset(
+        dataset_name,
+        split=args.split_name,
+        trust_remote_code=True,
+    )
+
+    # sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+    sampler = None
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        sampler=sampler,
+        shuffle=False,
+        num_workers=args.num_workers,
+        prefetch_factor=args.prefetch,
+        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
+    )
+    for _ in range(3):
+        print(f"Running {_} times")
+        total_llm_time = 0
+        total_token2wav_time = 0
+        total_data_load_time = 0
+        total_llm_post_processing_time = 0
+        total_audio_save_time = 0
+        total_audio_processing_time_in_collator = 0
+        total_speech_tokenization_time_in_collator = 0
+        total_text_tokenization_time_in_collator = 0
+        total_audio_samples = 0
+        start_time = time.time()
+        total_steps = len(dataset)
+
+        if rank == 0:
+            progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
+
+        last_batch_end_time = time.time()
+        for batch in dataloader:
+            data_loaded_time = time.time()
+            total_data_load_time += data_loaded_time - last_batch_end_time
+            total_audio_processing_time_in_collator += batch["audio_processing_time"]
+            total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
+            total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
+            with torch.no_grad():
+                # Generate speech tokens using LLM
+                llm_start_time = time.time()
+                if args.backend == "hf":
+                    input_ids_list = batch["input_ids"]
+                    if len(input_ids_list) == 1:
+                        input_ids = input_ids_list[0].unsqueeze(0)
+                        attention_mask = torch.ones_like(input_ids)
+                    else:
+                        # Handle batch > 1 if needed
+                        max_len = max([len(input_ids) for input_ids in input_ids_list])
+                        # input_ids_list_new = [
+                        #     torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
+                        #     for input_ids in input_ids_list
+                        # ]
+                        input_ids_list_new = [
+                            torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
+                            for input_ids in input_ids_list
+                        ]
+                        input_ids = torch.stack(input_ids_list_new)
+                        # compute attention mask
+                        attention_mask = torch.zeros_like(input_ids)
+                        for i in range(len(input_ids_list)):
+                            attention_mask[i, :len(input_ids_list[i])] = 1
+
+                        # breakpoint()
+
+
+                    input_ids = input_ids.to(device)
+
+                    outputs = model.generate(
+                        input_ids=input_ids.to(device),
+                        attention_mask=attention_mask.to(device),
+                        max_new_tokens=2048,  # Max length for generation
+                        do_sample=True,
+                        top_p=args.top_p,
+                        temperature=args.temperature,
+                        repetition_penalty=1.1,
+                        top_k=args.top_k,
+                    )
+                    torch.cuda.synchronize()
+                elif args.backend == "trtllm":
+                    # Convert input_ids to list of tensors for TensorRT-LLM
+                    batch_input_ids = [ids for ids in batch["input_ids"]]
+                    input_lengths = [x.size(0) for x in batch_input_ids]
+
+                    # Get end_id from tokenizer
+                    end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
+                    print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
+                    # random_seed=42,                         repetition_penalty=1.1,
+                    outputs = runner.generate(
+                        batch_input_ids=batch_input_ids,
+                        max_new_tokens=2048,
+                        end_id=end_id,
+                        pad_id=end_id,
+                        temperature=args.temperature,
+                        top_k=args.top_k,
+                        top_p=args.top_p,
+                        repetition_penalty=1.1,
+                        num_return_sequences=1,
+                        streaming=False,
+                        output_sequence_lengths=True,
+                        output_generation_logits=False,
+                        return_dict=True,
+                        return_all_generated_tokens=False
+                    )
+                    torch.cuda.synchronize()
+                    # Extract output_ids from TensorRT-LLM output
+                    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
+                    num_output_sents, num_beams, _ = output_ids.size()
+                    assert num_beams == 1
+                    beam = 0
+                    batch_size = len(batch["input_ids"])
+                    num_return_sequences = num_output_sents // batch_size
+                    assert num_return_sequences == 1
+                    outputs = []
+                    for i in range(batch_size * num_return_sequences):
+                        batch_idx = i // num_return_sequences
+                        seq_idx = i % num_return_sequences
+                        # inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist()
+                        # input_text = tokenizer.decode(inputs)
+                        # print(f'Input [Text {batch_idx}]: \"{input_text}\"')
+                        output_begin = input_lengths[batch_idx]
+                        output_end = sequence_lengths[i][beam]
+                        # outputs_i = output_ids[i][beam][output_begin:output_end].tolist()
+                        outputs_i = output_ids[i][beam][:output_end].tolist()
+                        outputs.append(outputs_i)
+                elif args.backend == "vllm":
+                    input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
+                    # prompts = [batch["prompt_text_after_apply_template"][i] for i in range(len(batch["prompt_text_after_apply_template"]))]
+                    # print(prompts)
+                    sampling_params = SamplingParams(
+                        temperature=args.temperature,
+                        top_p=args.top_p,
+                        top_k=args.top_k,
+                        repetition_penalty=1.1,
+                        max_tokens=2048,
+                    )
+                    outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
+                    # outputs = model.generate(prompts, sampling_params)
+                    print(outputs)
+                    # breakpoint()
+                    for j, output in enumerate(outputs):
+                        outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
+
+                llm_end_time = time.time()
+                total_llm_time += (llm_end_time - llm_start_time)
+
+                items_for_token2wav = []
+                for i in range(len(batch["ids"])):
+                    llm_post_processing_start_time = time.time()
+                    # Extract generated tokens (excluding input)
+                    input_length = len(batch["input_ids"][i])
+                    generated_ids = outputs[i][input_length:]  # Remove last token if needed
+                    speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+                    # Extract speech IDs from token strings like <|s_23456|>
+                    speech_ids = extract_speech_ids(speech_tokens_str)
+                    print(i, speech_ids)
+                    # breakpoint()
+                    if len(speech_ids) == 0:
+                        print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
+                        continue
+
+                    if args.prompt_text is not None:
+                        current_prompt_text = args.prompt_text
+                        current_prompt_audio = prompt_speech_16k
+                    else:
+                        current_prompt_text = batch["prompt_text"][i]
+                        current_prompt_audio = batch["prompt_audio_list"][i]
+
+                    llm_post_processing_end_time = time.time()
+                    total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
+                    if current_prompt_audio is not None:
+                        items_for_token2wav.append({
+                            "speech_ids": speech_ids,
+                            "prompt_audio": current_prompt_audio.squeeze(0),
+                            "id": batch["ids"][i]
+                        })
+                    else:
+                        print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
+
+                for i in range(0, len(items_for_token2wav), args.token2wav_batch_size):
+                    t2w_batch = items_for_token2wav[i:i + args.token2wav_batch_size]
+                    if not t2w_batch:
+                        continue
+
+                    t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
+                    t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
+                    t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
+                    t2w_ids = [item["id"] for item in t2w_batch]
+
+                    # Generate audio using CosyVoice2
+                    token2wav_start_time = time.time()
+                    generated_wavs = token2wav_model(
+                        t2w_generated_speech_tokens_list,
+                        t2w_prompt_audios_list,
+                        t2w_prompt_audios_sample_rate,
+                    )
+                    torch.cuda.synchronize()
+                    token2wav_end_time = time.time()
+                    total_token2wav_time += (token2wav_end_time - token2wav_start_time)
+
+                    audio_save_start_time = time.time()
+                    # Convert to numpy and save
+                    for j, audio_hat in enumerate(generated_wavs):
+                        generated_wave = audio_hat.squeeze().cpu().numpy()
+                        total_audio_samples += len(generated_wave)
+                        target_sample_rate = 24000
+
+                        utt = t2w_ids[j]
+                        sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
+                        print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
+                    audio_save_end_time = time.time()
+                    total_audio_save_time += audio_save_end_time - audio_save_start_time
+
+            if rank == 0:
+                progress_bar.update(world_size * len(batch["ids"]))
+
+            last_batch_end_time = time.time()
+        if rank == 0:
+            progress_bar.close()
+            end_time = time.time()
+            target_sample_rate = 24000
+            total_audio_duration_seconds = total_audio_samples / target_sample_rate
+
+            log_file_path = os.path.join(args.output_dir, "log.txt")
+            with open(log_file_path, 'w') as f:
+                # Convert Namespace to dict for JSON serialization
+                args_dict = vars(args)
+                log_data = {
+                    "args": args_dict,
+                    "data_load_time_seconds": total_data_load_time,
+                    "audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
+                    "speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
+                    "text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
+                    "llm_time_seconds": total_llm_time,
+                    "llm_post_processing_time_seconds": total_llm_post_processing_time,
+                    "token2wav_time_seconds": total_token2wav_time,
+                    "audio_save_time_seconds": total_audio_save_time,
+                    "total_audio_duration_seconds": total_audio_duration_seconds,
+                    "pipeline_time_seconds": end_time - start_time,
+                }
+                print(log_data)
+                f.write(json.dumps(log_data, indent=4))
+            print(f"Metrics logged to {log_file_path}")
+
+
+if __name__ == "__main__":
+    args = get_args()
+    if args.backend == "vllm":
+        from vllm import LLM, SamplingParams
+    elif args.backend == "trtllm":
+        import tensorrt_llm
+        from tensorrt_llm.runtime import ModelRunnerCpp
+    elif args.backend == "hf":
+        from transformers import AutoModelForCausalLM
+    else:
+        raise ValueError(f"Unsupported backend: {args.backend}")
+    main(args)

+ 25 - 0
runtime/triton_trtllm/run.sh

@@ -27,6 +27,7 @@ fi
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     echo "Downloading CosyVoice2-0.5B"
+    # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
     huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
     modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
     # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
@@ -115,3 +116,27 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
         --huggingface-dataset yuekai/seed_tts_cosy2 \
         --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
 fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  echo "stage 6: Offline inference benchmark"
+  n_gpus=1
+  datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
+  backend=trtllm # hf, trtllm, vllm
+
+  batch_sizes=(16 8 4 2 1)
+  token2wav_batch_size=1
+  for batch_size in ${batch_sizes[@]}; do
+    for dataset in ${datasets[@]}; do
+    output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 offline_inference.py \
+            --output-dir $output_dir \
+            --llm-model-name-or-path $huggingface_model_local_dir \
+            --token2wav-path $model_scope_model_local_dir \
+            --backend $backend \
+            --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+            --engine-dir $trt_engines_dir \
+            --split-name ${dataset} || exit 1
+    done
+  done
+fi

+ 336 - 0
runtime/triton_trtllm/token2wav.py

@@ -0,0 +1,336 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Example Usage
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 token2wav.py --enable-trt || exit 1
+"""
+import torch
+from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
+from flashcosyvoice.modules.hifigan import HiFTGenerator
+from flashcosyvoice.utils.audio import mel_spectrogram
+import torchaudio.compliance.kaldi as kaldi
+import onnxruntime
+import s3tokenizer
+from torch.utils.data import DataLoader
+from datasets import load_dataset
+import torchaudio
+import os
+import logging
+import argparse
+import queue
+import time
+
+
+def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
+    import tensorrt as trt
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
+    if fp16:
+        config.set_flag(trt.BuilderFlag.FP16)
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(trt_kwargs['input_names'])):
+        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
+    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
+    logging.info("Succesfully convert onnx to trt...")
+
+class TrtContextWrapper:
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        self.trt_engine = trt_engine
+        self.device = device
+        for _ in range(trt_concurrent):
+            trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
+            assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
+            self.trt_context_pool.put([trt_context, trt_stream])
+        assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
+
+    def acquire_estimator(self):
+        return self.trt_context_pool.get(), self.trt_engine
+
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])
+
+class CosyVoice2_Token2Wav(torch.nn.Module):
+    def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
+        super().__init__()
+        self.device_id = device_id
+        self.device = f"cuda:{device_id}"
+        
+        self.flow = CausalMaskedDiffWithXvec()
+        self.flow.half()
+        self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
+        self.flow.to(self.device).eval()
+
+        self.hift = HiFTGenerator()
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
+                                                    providers=["CPUExecutionProvider"])
+        
+        self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()
+
+        gpu="l20"
+        if enable_trt:
+            self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
+                                f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                                1,
+                                True)
+            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
+                                f'{model_dir}/campplus.onnx',
+                                1,
+                                False)
+
+
+    def forward_spk_embedding(self, spk_feat):
+        if isinstance(self.spk_model, onnxruntime.InferenceSession):
+            return self.spk_model.run(
+                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+            )[0].flatten().tolist()
+        else:
+            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
+            # NOTE need to synchronize when switching stream
+            with torch.cuda.device(self.device_id):
+                torch.cuda.current_stream().synchronize()
+                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
+                batch_size = spk_feat.size(0)
+
+                with stream:
+                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
+                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
+
+                    data_ptrs = [spk_feat.contiguous().data_ptr(),
+                                 output_tensor.contiguous().data_ptr()]
+                    for i, j in enumerate(data_ptrs):
+
+                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
+                    # run trt engine
+                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                    torch.cuda.current_stream().synchronize()
+                self.spk_model.release_estimator(spk_model, stream)
+
+            return output_tensor.cpu().numpy().flatten().tolist()
+
+    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
+        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
+            trt_kwargs = self.get_spk_trt_kwargs()
+            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
+        import tensorrt as trt
+        with open(spk_model, 'rb') as f:
+            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
+        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_spk_trt_kwargs(self):
+        min_shape = [(1, 4, 80)]
+        opt_shape = [(1, 500, 80)]
+        max_shape = [(1, 3000, 80)]
+        input_names = ["input"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
+            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=2, max_batch_size=16)
+            convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64):
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
+        opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
+        max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+        input_names = ["x", "mask", "mu", "cond", "t", "spks"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
+        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
+            prompt_speech_mels_list.append(log_mel)
+        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
+        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
+            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
+        )
+        for i in range(len(prompt_speech_tokens)):
+            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
+            prompt_speech_tokens_list.append(speech_tokens_i)
+        return prompt_speech_tokens_list
+    
+    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
+        spk_emb_for_flow = []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
+            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
+            spk_emb = self.forward_spk_embedding(spk_feat)
+
+            spk_emb_for_flow.append(spk_emb)
+        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)    
+        return spk_emb_for_flow
+    
+    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
+        prompt_mels_for_flow = []
+        prompt_mels_lens_for_flow = []
+        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
+            assert len(audio.shape) == 1
+            audio = audio.unsqueeze(0)
+            if sample_rate != 24000:
+                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
+            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
+            mel_len = mel.shape[0]
+            prompt_mels_for_flow.append(mel)
+            prompt_mels_lens_for_flow.append(mel_len)
+        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
+        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
+        return prompt_mels_for_flow, prompt_mels_lens_for_flow
+    
+
+    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
+        batch_size = prompt_mels_for_flow.shape[0]
+        flow_inputs = []
+        flow_inputs_lens = []
+        for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
+            flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
+            flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
+
+        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
+        flow_inputs_lens = torch.tensor(flow_inputs_lens)
+
+        with torch.amp.autocast(self.device, dtype=torch.float16):
+            generated_mels, generated_mels_lens = self.flow(
+                flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
+                prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device),
+                streaming=False, finalize=True
+            )
+
+        return generated_mels, generated_mels_lens
+
+    def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
+        batch_size = generated_mels.shape[0]
+        generated_wavs = []
+        for i in range(batch_size):
+            mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
+            wav, _ = self.hift(speech_feat=mel)
+            generated_wavs.append(wav)
+        return generated_wavs
+
+
+    @torch.inference_mode()
+    def forward(
+        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
+    ):
+        # assert all item in prompt_audios_sample_rate is 16000
+        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
+        
+
+        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
+
+        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
+
+        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
+
+        generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+
+        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
+        
+        return generated_wavs
+
+
+def collate_fn(batch):
+    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
+    for i, item in enumerate(batch):
+        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
+        audio = torch.from_numpy(item['prompt_audio']['array']).float() 
+        prompt_audios_list.append(audio)
+        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
+        ids.append(item['id'])
+
+    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--enable-trt", action="store_true")
+    parser.add_argument("--model-dir", type=str, default="./CosyVoice2-0.5B")
+    parser.add_argument("--batch-size", type=int, default=4)
+    parser.add_argument("--output-dir", type=str, default="generated_wavs")
+    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
+    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    args = get_args()
+    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
+    # mkdir output_dir if not exists
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    dataset_name = "yuekai/seed_tts_cosy2"
+
+    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
+
+
+    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
+    
+    
+    for epoch in range(args.warmup):
+        start_time = time.time()
+        
+        for batch in data_loader:
+            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
+
+            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
+            
+
+            for id, wav in zip(ids, generated_wavs):
+                torchaudio.save(f"{args.output_dir}/{id}.wav", wav.cpu(), 24000)
+        
+        end_time = time.time()
+        epoch_time = end_time - start_time
+        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")