
init step-audio2 token2wav

yuekaiz 2 months ago
parent
commit
b207c60885

+ 455 - 0
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py

@@ -0,0 +1,455 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import math
+import os
+import re
+import threading
+import time
+from typing import Dict, List, Tuple, Optional, Union
+
+import numpy as np
+import torch
+from torch.utils.dlpack import from_dlpack, to_dlpack
+import triton_python_backend_utils as pb_utils
+from transformers import AutoTokenizer
+
+import torchaudio
+
+
+from matcha.utils.audio import mel_spectrogram
+
+ORIGINAL_VOCAB_SIZE = 151663
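+# offset separating text tokens from speech tokens in the LLM vocabulary: cached
+# spk2info speech tokens are shifted up by this value before being fed to the LLM,
+# and token2wav shifts generated ids back down by the same amount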
+torch.set_num_threads(1)
+
+
+class TritonPythonModel:
+    """Triton Python model for Spark TTS.
+
+    This model orchestrates the end-to-end TTS pipeline by coordinating
+    between audio tokenizer, LLM, and vocoder components.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        self.logger = pb_utils.Logger
+        # Parse model parameters
+        self.model_config = json.loads(args['model_config'])
+        parameters = self.model_config['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+        self.logger.log_info(f"model_params:{model_params}")
+        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
+        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
+
+        # Initialize tokenizer
+        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
+        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
+        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
+        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
+
+        self.device = torch.device("cuda")
+        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
+
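+        # streaming parameters: the LLM produces 25 speech tokens per second of audio;
+        # each streamed chunk covers token_hop_len tokens plus flow_pre_lookahead_len
+        # lookahead tokens passed to the flow decoder as right context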
+        self.token_frame_rate = 25
+        self.flow_pre_lookahead_len = 3
+        self.token_hop_len = 15
+
+        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        self.default_spk_info = spk_info["001"]
+
+    def forward_llm(self, input_ids):
+        """
+        Prepares the response from the language model based on the provided
+        inputs. Creates a `pb_utils.InferenceRequest` object with passed
+        `llm_request_inputs` to send to a decoupled TensorRTLLM model.
+        For each response from the language model:
+            - Checks for errors and raise an exception if any are found.
+            - Extracts the "output_ids" tensor from the response.
+            - Determines the finish reason based on the presence of the
+              end-of-sequence token or reaching the maximum length.
+            - Appends the generated token IDs to `output_ids`.
+            - If the finish reason is determined, decodes the output IDs to text
+              and prepares the final response.
+
+        The final response includes the generated text, finish reason,
+        completion tokens, prompt tokens, and total tokens.
+
+        Parameters
+        ----------
+        - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
+
+        Returns
+        -------
+        - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
+        """
+        # convert input_ids to numpy, with shape [1, sequence_length]
+        input_ids = input_ids.cpu().numpy()
+        max_tokens = 750
+        input_dict = {
+            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
+            "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
+            "pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
+            "streaming": np.array([[self.decoupled]], dtype=np.bool_),
+            "runtime_top_p": np.array([[0.95]], dtype=np.float32),
+            "runtime_top_k": np.array([[50]], dtype=np.int32),
+            "temperature": np.array([[0.8]], dtype=np.float32),
+            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
+            "random_seed": np.array([[42]], dtype=np.uint64),
+            "input_ids": input_ids,
+            "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
+        }
+
+        # Convert inputs to Triton tensors
+        input_tensor_list = [
+            pb_utils.Tensor(k, v) for k, v in input_dict.items()
+        ]
+
+        # Create and execute inference request
+        llm_request = pb_utils.InferenceRequest(
+            model_name="tensorrt_llm",
+            requested_output_names=["output_ids", "sequence_length"],
+            inputs=input_tensor_list,
+        )
+
+        llm_responses = llm_request.exec(decoupled=self.decoupled)
+        if self.decoupled:
+            for llm_response in llm_responses:
+                if llm_response.has_error():
+                    raise pb_utils.TritonModelException(llm_response.error().message())
+
+                # Extract and process output
+                output_ids = pb_utils.get_output_tensor_by_name(
+                    llm_response, "output_ids").as_numpy()
+                seq_lens = pb_utils.get_output_tensor_by_name(
+                    llm_response, "sequence_length").as_numpy()
+
+                # Get actual output IDs up to the sequence length
+                actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
+
+                yield actual_output_ids
+        else:
+            llm_response = llm_responses
+            if llm_response.has_error():
+                raise pb_utils.TritonModelException(llm_response.error().message())
+
+            # Extract and process output
+            output_ids = pb_utils.get_output_tensor_by_name(
+                llm_response, "output_ids").as_numpy()
+            seq_lens = pb_utils.get_output_tensor_by_name(
+                llm_response, "sequence_length").as_numpy()
+
+            # Get actual output IDs up to the sequence length
+            actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
+
+            yield actual_output_ids
+
+    def forward_audio_tokenizer(self, wav, wav_len):
+        """Forward pass through the audio tokenizer component.
+
+        Args:
+            wav: Input waveform tensor
+            wav_len: Waveform length tensor
+
+        Returns:
+            Prompt speech tokens tensor
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='audio_tokenizer',
+            requested_output_names=['prompt_speech_tokens'],
+            inputs=[wav, wav_len]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
+        prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
+
+        return prompt_speech_tokens
+
+    def forward_speaker_embedding(self, wav):
+        """Forward pass through the speaker embedding component.
+
+        Args:
+            wav: Input waveform tensor
+
+        Returns:
+            Prompt speaker embedding tensor
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='speaker_embedding',
+            requested_output_names=['prompt_spk_embedding'],
+            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
+        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
+
+        return prompt_spk_embedding
+
+    def forward_token2wav(
+            self,
+            target_speech_tokens: torch.Tensor,
+            request_id: str,
+            prompt_speech_tokens: torch.Tensor = None,
+            prompt_speech_feat: torch.Tensor = None,
+            prompt_spk_embedding: torch.Tensor = None,
+            token_offset: int = None,
+            finalize: bool = None) -> torch.Tensor:
+        """Forward pass through the vocoder component.
+
+        Args:
+            prompt_speech_tokens: Prompt speech tokens tensor
+            prompt_speech_feat: Prompt speech feat tensor
+            prompt_spk_embedding: Prompt spk embedding tensor
+            target_speech_tokens: Target speech tokens tensor
+
+        Returns:
+            Generated waveform tensor
+        """
+        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
+
+        inputs_tensor = [target_speech_tokens_tensor]
+
+        if token_offset is not None:
+            assert finalize is not None
+            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
+            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+            inputs_tensor.append(token_offset_tensor)
+            inputs_tensor.append(finalize_tensor)
+
+        if prompt_spk_embedding is not None:
+            assert prompt_speech_feat is not None
+            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
+            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
+            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
+            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
+
+        # Create and execute inference request
+        inference_request = pb_utils.InferenceRequest(
+            model_name='token2wav',
+            requested_output_names=['waveform'],
+            inputs=inputs_tensor,
+            request_id=request_id,
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output waveform
+        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
+        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
+
+        return waveform
+
+    def parse_input(self, text, prompt_text, prompt_speech_tokens):
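+        # prompt layout: <|sos|>{prompt_text}{target_text}<|task_id|> followed by the
+        # prompt speech tokens; the LLM then continues with speech tokens for the
+        # target text until it emits <|eos1|>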
+        total_text = f"{prompt_text}{text}"
+        prompt = self.prompt_template.format(input_text=total_text)
+        input_ids = self.tokenizer.encode(prompt)
+        input_ids = torch.tensor([input_ids], dtype=torch.int32)
+        input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
+        return input_ids
+
+    def _extract_speech_feat(self, speech):
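+        # 80-bin mel at 24 kHz with a 480-sample hop -> 50 frames/s, i.e. two mel
+        # frames per 25 Hz speech token (hence the 2x factor applied by the caller)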
+        speech_feat = mel_spectrogram(
+            speech,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000).squeeze(
+            dim=0).transpose(
+            0,
+            1).to(
+                self.device)
+        speech_feat = speech_feat.unsqueeze(dim=0)
+        return speech_feat
+
+    def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
+        for generated_ids in generated_ids_iter:
+            generated_ids = generated_ids.tolist()
+            if len(generated_ids) == 0:
+                break
+            semantic_token_ids_arr.extend(generated_ids)
+        llm_is_done_flag[0] = True
+
+    def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated audio
+        """
+        responses = []
+
+        for request in requests:
+            request_id = request.request_id()
+            # Extract input tensors
+            wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+
+            # Process reference audio through audio tokenizer
+            if wav is not None:
+                wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+                prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+                prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+                wav_tensor = wav.as_numpy()
+                wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+                prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+                speech_feat = self._extract_speech_feat(prompt_speech_resample)
+                token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+                prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+                prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+                reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+                reference_text = reference_text[0][0].decode('utf-8')
+                prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+            else:
+                # using pre-cached reference text
+                reference_text = self.default_spk_info["prompt_text"]
+                prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+                prompt_speech_feat = None
+                prompt_spk_embedding = None
+
+            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+            target_text = target_text[0][0].decode('utf-8')
+
+            # Prepare prompt for LLM
+            input_ids = self.parse_input(
+                text=target_text,
+                prompt_text=reference_text,
+                prompt_speech_tokens=prompt_speech_tokens,
+            )
+
+            # Generate semantic tokens with LLM
+            generated_ids_iter = self.forward_llm(input_ids)
+
+            if self.decoupled:
+                response_sender = request.get_response_sender()
+
+                semantic_token_ids_arr = []
+                llm_is_done_flag = [False]
+
+                llm_thread = threading.Thread(
+                    target=self._llm_gen_thread,
+                    args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
+                )
+
+                llm_thread.start()
+
+                token_offset, chunk_index = 0, 0
+                start_time = time.time()
+                this_token_hop_len = self.token_hop_len
+
+                while True:
+                    pending_num = len(semantic_token_ids_arr) - token_offset
+
+                    if llm_is_done_flag[0]:
+                        break
+
+                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
+                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
+                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
+
+                        sub_tts_speech = self.forward_token2wav(
+                            this_tts_speech_token, request_id, prompt_speech_tokens,
+                            prompt_speech_feat, prompt_spk_embedding, token_offset, False
+                        )
+
+                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                        token_offset += this_token_hop_len
+                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
+
+                        if self.dynamic_chunk_strategy == "exponential":
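+                            # first chunk uses token_hop_len (15 tokens); subsequent
+                            # chunks grow as 25, 50, 100, ... tokens (1 s, 2 s, 4 s of
+                            # audio at 25 tokens/s), trading latency for throughput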
+                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "time_based":
+                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
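+                            # grow the next chunk only when synthesis runs well ahead of
+                            # real time: "multiples" estimates how many average chunk
+                            # processing times fit into the audio slack already emitted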
+                            cost_time = time.time() - start_time
+                            duration = token_offset / self.token_frame_rate
+                            if chunk_index > 0 and cost_time > 0:
+                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
+                                if avg_chunk_processing_time > 0:
+                                    multiples = (duration - cost_time) / avg_chunk_processing_time
+                                    self.logger.log_info(f"multiples: {multiples}")
+                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
+                                    if multiples > 4:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
+                                    elif multiples > 2:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
+                                    else:
+                                        this_token_hop_len = self.token_hop_len
+                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
+                        chunk_index += 1
+                    else:
+                        time.sleep(0.02)
+
+                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
+                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                response_sender.send(inference_response)
+
+                llm_thread.join()
+                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+                self.logger.log_info("send tritonserver_response_complete_final to end")
+            else:
+                generated_ids = next(generated_ids_iter)
+                if generated_ids is None or len(generated_ids) == 0:
+                    raise pb_utils.TritonModelException("Generated IDs is None or empty")
+                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
+
+                audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
+
+                # Prepare response
+                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
+                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                responses.append(inference_response)
+
+        if not self.decoupled:
+            return responses

+ 73 - 0
runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt

@@ -0,0 +1,73 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "cosyvoice2"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+model_transaction_policy {
+  decoupled: ${decoupled_mode}
+}
+parameters [
+  {
+   key: "llm_tokenizer_dir",
+   value: {string_value:"${llm_tokenizer_dir}"}
+  },
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "reference_wav"
+    data_type: TYPE_FP32
+    dims: [-1]
+    optional: true
+  },
+  {
+    name: "reference_wav_len"
+    data_type: TYPE_INT32
+    dims: [1]
+    optional: true
+  },
+  {
+    name: "reference_text"
+    data_type: TYPE_STRING
+    dims: [1]
+    optional: true
+  },
+  {
+    name: "target_text"
+    data_type: TYPE_STRING
+    dims: [1]
+  }
+]
+output [
+  {
+    name: "waveform"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: ${bls_instance_num}
+    kind: KIND_CPU
+  }
+]
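
A minimal offline-mode (DECOUPLED_MODE=False) client sketch for the interface declared above. This is an illustrative assumption rather than part of the commit: it uses the standard tritonclient.grpc API, assumes the server listens on localhost:8001, and writes the 24 kHz result with soundfile; only the tensor names and dtypes come from the config.

    import numpy as np
    import soundfile as sf
    import tritonclient.grpc as grpcclient

    client = grpcclient.InferenceServerClient("localhost:8001")

    # target_text is the only required input; the reference_* inputs are optional and
    # the model falls back to the cached spk2info speaker when they are omitted.
    text = np.array([["身临其境,换新体验。"]], dtype=object)
    target_text = grpcclient.InferInput("target_text", [1, 1], "BYTES")
    target_text.set_data_from_numpy(text)

    result = client.infer(
        model_name="cosyvoice2",
        inputs=[target_text],
        outputs=[grpcclient.InferRequestedOutput("waveform")],
    )
    sf.write("output.wav", result.as_numpy("waveform").reshape(-1), 24000)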

+ 278 - 0
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py

@@ -0,0 +1,278 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import os
+
+import logging
+from typing import List, Dict
+
+import torch
+from torch.utils.dlpack import to_dlpack
+from torch.nn import functional as F
+
+import triton_python_backend_utils as pb_utils
+
+from hyperpyyaml import load_hyperpyyaml
+from cosyvoice.utils.common import fade_in_out
+from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
+from cosyvoice.utils.common import TrtContextWrapper
+from collections import defaultdict
+import numpy as np
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+ORIGINAL_VOCAB_SIZE = 151663
+torch.set_num_threads(1)
+
+
+class CosyVoice2:
+
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
+
+        self.model_dir = model_dir
+        self.fp16 = fp16
+
+        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
+        self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
+                                self.fp16)
+
+
+class CosyVoice2Model:
+
+    def __init__(self,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module,
+                 fp16: bool = False,
+                 device: str = 'cuda'):
+        self.device = device
+        self.flow = flow
+        self.hift = hift
+        self.fp16 = fp16
+        if self.fp16 is True:
+            self.flow.half()
+
+        # streaming tts config
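+        # the last mel_cache_len mel frames (~160 ms) and the matching waveform tail
+        # are cached per request and cross-faded with a Hamming window into the next
+        # chunk so that streamed chunks join without clicks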
+        self.token_hop_len = 25
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        self.hift_cache_dict = defaultdict(lambda: None)
+
+    def load_jit(self, flow_encoder_model):
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load(self, flow_model, hift_model):
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+        self.flow.to(self.device).eval()
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
+            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_trt_kwargs(self):
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
+        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
+        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
+        input_names = ["x", "mask", "mu", "cond"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             streaming=stream,
+                                             finalize=finalize)
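+        # the flow model returns mel frames for all tokens seen so far; keep only the
+        # frames after token_offset, i.e. those not already rendered in earlier chunks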
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
+
+class TritonPythonModel:
+    """Triton Python model for vocoder.
+
+    This model takes global and semantic tokens as input and generates audio waveforms
+    using the BiCodec vocoder.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        # Parse model parameters
+        parameters = json.loads(args['model_config'])['parameters']
+        model_params = {key: value["string_value"] for key, value in parameters.items()}
+        model_dir = model_params["model_dir"]
+
+        # Initialize device and vocoder
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
+
+        self.token2wav_model = CosyVoice2(
+            model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
+        )
+
+        spk_info_path = os.path.join(model_dir, "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_dir}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        self.default_spk_info = spk_info["001"]
+
+        logger.info("Token2Wav initialized successfully")
+
+    def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated waveforms
+        """
+        responses = []
+        # Process each request in batch
+        for request in requests:
+            target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
+            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
+
+            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
+            if prompt_speech_tokens_tensor is not None:
+                prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
+                prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
+                prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
+                prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
+                prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
+                prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
+                prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
+            else:
+                prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
+                prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
+                prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
+
+            # shift the speech tokens according to the original vocab size
+            target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
+
+            # token_offset is an optional input that enables streaming TTS; it must be omitted for offline TTS.
+            token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
+            if token_offset is not None:
+                token_offset = token_offset.as_numpy().item()
+                finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+                if not finalize:
+                    stream = True
+                else:
+                    stream = False
+                request_id = request.request_id()
+                audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
+                                                                 prompt_token=prompt_speech_tokens,
+                                                                 prompt_feat=prompt_speech_feat,
+                                                                 embedding=prompt_spk_embedding,
+                                                                 token_offset=token_offset,
+                                                                 uuid=request_id,
+                                                                 stream=stream,
+                                                                 finalize=finalize)
+                if finalize:
+                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
+
+            else:
+                tts_mel, _ = self.token2wav_model.model.flow.inference(
+                    token=target_speech_tokens,
+                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
+                        self.device
+                    ),
+                    prompt_token=prompt_speech_tokens,
+                    prompt_token_len=torch.tensor(
+                        [prompt_speech_tokens.shape[1]], dtype=torch.int32
+                    ).to(self.device),
+                    prompt_feat=prompt_speech_feat,
+                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
+                    embedding=prompt_spk_embedding,
+                    streaming=False,
+                    finalize=True,
+                )
+
+                audio_hat, _ = self.token2wav_model.model.hift.inference(
+                    speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+                )
+
+            generated_wave = audio_hat.squeeze(0).cpu().numpy()
+
+            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
+            responses.append(inference_response)
+
+        return responses

+ 80 - 0
runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt

@@ -0,0 +1,80 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "token2wav"
+backend: "python"
+max_batch_size: ${triton_max_batch_size}
+dynamic_batching {
+    max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+}
+parameters [
+  {
+   key: "model_dir",
+   value: {string_value:"${model_dir}"}
+  }
+]
+
+input [
+  {
+    name: "target_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+  },
+  {
+    name: "prompt_speech_tokens"
+    data_type: TYPE_INT32
+    dims: [-1]
+    optional: true
+  },
+  {
+    name: "prompt_speech_feat"
+    data_type: TYPE_FP16
+    dims: [-1, 80]
+    optional: true
+  },
+  {
+    name: "prompt_spk_embedding"
+    data_type: TYPE_FP16
+    dims: [-1]
+    optional: true
+  },
+  {
+    name: "token_offset"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
+  {
+    name: "finalize"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  }
+]
+output [
+  {
+    name: "waveform"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]

+ 142 - 0
runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh

@@ -0,0 +1,142 @@
+#!/bin/bash
+# Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
+export CUDA_VISIBLE_DEVICES=0
+cosyvoice_path=/workspace/CosyVoice
+export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
+export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
+stage=$1
+stop_stage=$2
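+# usage: bash run_stepaudio2_dit_token2wav.sh <stage> <stop_stage>
+# e.g. "bash run_stepaudio2_dit_token2wav.sh 0 3" runs model download, engine build,
+# model-repo creation and server launch in sequence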
+
+huggingface_model_local_dir=./cosyvoice2_llm
+model_scope_model_local_dir=./CosyVoice2-0.5B
+trt_dtype=bfloat16
+trt_weights_dir=./trt_weights_${trt_dtype}
+trt_engines_dir=./trt_engines_${trt_dtype}
+
+model_repo=./model_repo_cosyvoice2
+
+use_spk2info_cache=False
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+    echo "Cloning CosyVoice"
+    git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
+    cd $cosyvoice_path
+    git submodule update --init --recursive
+    cd runtime/triton_trtllm
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+    echo "Downloading CosyVoice2-0.5B"
+    # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
+    huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
+    modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
+    # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
+    wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
+fi
+
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+    echo "Converting checkpoint to TensorRT weights"
+    python3 scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir \
+                                --output_dir $trt_weights_dir \
+                                --dtype $trt_dtype || exit 1
+
+    echo "Building TensorRT engines"
+    trtllm-build --checkpoint_dir $trt_weights_dir \
+                --output_dir $trt_engines_dir \
+                --max_batch_size 16 \
+                --max_num_tokens 32768 \
+                --gemm_plugin $trt_dtype || exit 1
+
+    echo "Testing TensorRT engines"
+    python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \
+                    --tokenizer_dir $huggingface_model_local_dir \
+                    --top_k 50 --top_p 0.95 --temperature 0.8 \
+                    --engine_dir=$trt_engines_dir  || exit 1
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+    echo "Creating model repository"
+    rm -rf $model_repo
+    mkdir -p $model_repo
+    cosyvoice2_dir="cosyvoice2"
+
+    cp -r ./model_repo/${cosyvoice2_dir} $model_repo
+    cp -r ./model_repo/tensorrt_llm $model_repo
+    cp -r ./model_repo/token2wav $model_repo
+    if [ $use_spk2info_cache == "False" ]; then
+        cp -r ./model_repo/audio_tokenizer $model_repo
+        cp -r ./model_repo/speaker_embedding $model_repo
+    fi
+
+    ENGINE_PATH=$trt_engines_dir
+    MAX_QUEUE_DELAY_MICROSECONDS=0
+    MODEL_DIR=$model_scope_model_local_dir
+    LLM_TOKENIZER_DIR=$huggingface_model_local_dir
+    BLS_INSTANCE_NUM=4
+    TRITON_MAX_BATCH_SIZE=16
+    DECOUPLED_MODE=True # True for streaming, False for offline
+
+    python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
+    if [ $use_spk2info_cache == "False" ]; then
+        python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+        python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+   echo "Starting Triton server"
+   tritonserver --model-repository $model_repo
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+    echo "Single request test http, only work for offline TTS mode"
+    python3 client_http.py \
+        --reference-audio ./assets/prompt_audio.wav \
+        --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
+        --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
+        --model-name cosyvoice2
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+    echo "Running benchmark client grpc"
+    num_task=4
+
+    mode=streaming
+    BLS_INSTANCE_NUM=4
+
+    python3 client_grpc.py \
+        --server-addr localhost \
+        --model-name cosyvoice2 \
+        --num-tasks $num_task \
+        --mode $mode \
+        --use-spk2info-cache $use_spk2info_cache \
+        --huggingface-dataset yuekai/seed_tts_cosy2 \
+        --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  echo "stage 6: Offline inference benchmark"
+  n_gpus=1
+  datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
+  backend=trtllm # hf, trtllm, vllm
+
+  batch_sizes=(16 8 4 2 1)
+  token2wav_batch_size=1
+  for batch_size in ${batch_sizes[@]}; do
+    for dataset in ${datasets[@]}; do
+    output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 offline_inference.py \
+            --output-dir $output_dir \
+            --llm-model-name-or-path $huggingface_model_local_dir \
+            --token2wav-path $model_scope_model_local_dir \
+            --backend $backend \
+            --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
+            --engine-dir $trt_engines_dir \
+            --split-name ${dataset} || exit 1
+    done
+  done
+fi

+ 496 - 0
runtime/triton_trtllm/token2wav_dit.py

@@ -0,0 +1,496 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Example Usage
+    CUDA_VISIBLE_DEVICES=0 \
+        python3 token2wav_dit.py --enable-trt || exit 1
+"""
+import torch
+# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
+from flashcosyvoice.modules.hifigan import HiFTGenerator
+from flashcosyvoice.utils.audio import mel_spectrogram
+import torchaudio.compliance.kaldi as kaldi
+import onnxruntime
+import s3tokenizer
+from torch.utils.data import DataLoader
+from datasets import load_dataset
+import torchaudio
+import os
+import logging
+import argparse
+import queue
+import time
+import numpy as np
+from hyperpyyaml import load_hyperpyyaml
+
+
+def fade_in_out(fade_in_mel:torch.Tensor, fade_out_mel:torch.Tensor, window:torch.Tensor):
+    """perform fade_in_out in tensor style
+    """
+    mel_overlap_len = int(window.shape[0] / 2)
+    fade_in_mel = fade_in_mel.clone()
+    fade_in_mel[..., :mel_overlap_len] = \
+        fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
+    return fade_in_mel
+
+def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
+    import tensorrt as trt
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
+    if dtype == torch.float16:
+        config.set_flag(trt.BuilderFlag.FP16)
+    elif dtype == torch.bfloat16:
+        config.set_flag(trt.BuilderFlag.BF16)
+    elif dtype == torch.float32:
+        pass  # FP32 is TensorRT's default precision; no builder flag is needed
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(trt_kwargs['input_names'])):
+        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
+    if dtype == torch.float16:
+        tensor_dtype = trt.DataType.HALF
+    elif dtype == torch.bfloat16:
+        tensor_dtype = trt.DataType.BF16
+    elif dtype == torch.float32:
+        tensor_dtype = trt.DataType.FLOAT
+    else:
+        raise ValueError('invalid dtype {}'.format(dtype))
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)
+    logging.info("Succesfully convert onnx to trt...")
+
+class TrtContextWrapper:
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        self.trt_engine = trt_engine
+        self.device = device
+        for _ in range(trt_concurrent):
+            trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
+            assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
+            self.trt_context_pool.put([trt_context, trt_stream])
+        assert self.trt_context_pool.empty() is False, 'no available estimator context'
+
+    def acquire_estimator(self):
+        return self.trt_context_pool.get(), self.trt_engine
+
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])
+
+class CosyVoice2_Token2Wav(torch.nn.Module):
+    def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
+        super().__init__()
+        self.device_id = device_id
+        self.device = f"cuda:{device_id}"
+        with open(f"{model_dir}/flow.yaml", "r") as f:
+            configs = load_hyperpyyaml(f)
+            self.flow = configs['flow']
+
+        self.dtype = dtype
+        self.flow.to(self.dtype)
+
+        self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
+        self.flow.to(self.device).eval()
+
+        self.hift = HiFTGenerator()
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
+                                                    providers=["CPUExecutionProvider"])
+        
+        self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()
+
+        gpu="l20"
+        if enable_trt:
+            if streaming:
+                self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
+                                    f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
+                                    1,
+                                    self.dtype, streaming)
+            else:
+                self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
+                                    f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                                    1,
+                                    self.dtype)
+            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
+                                f'{model_dir}/campplus.onnx',
+                                1,
+                                False)
+
+
+        self.streaming_flow_cache = {}
+        self.speaker_cache = {}
+
+        self.mel_cache_len = 8  # hard-coded, 160ms
+        self.source_cache_len = int(self.mel_cache_len * 480)   # 50hz mel -> 24kHz wave
+        self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
+
+        # hifigan cache for streaming tts
+        self.hift_cache_dict = {}
+
+    def forward_spk_embedding(self, spk_feat):
+        if isinstance(self.spk_model, onnxruntime.InferenceSession):
+            return self.spk_model.run(
+                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
+            )[0].flatten().tolist()
+        else:
+            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
+            # NOTE need to synchronize when switching stream
+            with torch.cuda.device(self.device_id):
+                torch.cuda.current_stream().synchronize()
+                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
+                batch_size = spk_feat.size(0)
+
+                with stream:
+                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
+                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
+
+                    data_ptrs = [spk_feat.contiguous().data_ptr(),
+                                 output_tensor.contiguous().data_ptr()]
+                    for i, j in enumerate(data_ptrs):
+
+                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
+                    # run trt engine
+                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                    torch.cuda.current_stream().synchronize()
+                self.spk_model.release_estimator(spk_model, stream)
+
+            return output_tensor.cpu().numpy().flatten().tolist()
+
+    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
+        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
+            trt_kwargs = self.get_spk_trt_kwargs()
+            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
+        import tensorrt as trt
+        with open(spk_model, 'rb') as f:
+            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
+        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_spk_trt_kwargs(self):
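+        # TensorRT optimization profile for the speaker model: a single 80-dim fbank
+        # input of 4 to 3000 frames, batch size fixed to 1.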
+        min_shape = [(1, 4, 80)]
+        opt_shape = [(1, 500, 80)]
+        max_shape = [(1, 3000, 80)]
+        input_names = ["input"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
+            opt_batch_size = 2
+            max_batch_size = 16
+            if streaming:
+                opt_batch_size, max_batch_size = 1, 1 # only support batch size 1 for streaming tts
+            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
+            convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
+
+    def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
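+        # Optimization profiles for the flow-matching estimator. The batch axis is
+        # doubled (*2) because the estimator stacks a conditional and an unconditional
+        # pass per utterance (classifier-free guidance); the streaming profile also
+        # exposes the per-layer cnn/attention caches of the chunk-wise estimator.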
+        if streaming:
+            min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
+            opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80), (16, opt_batch_size*2, 1024, 2), (16, opt_batch_size*2, 8, 100, 128)]
+            max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80), (16, max_batch_size*2, 1024, 2), (16, max_batch_size*2, 8, 1000, 128)]
+            input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
+        else:
+            min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
+            opt_shape = [(opt_batch_size*2, 80, 500), (opt_batch_size*2, 1, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2, 80, 500), (opt_batch_size*2,), (opt_batch_size*2, 80)]
+            max_shape = [(max_batch_size*2, 80, 3000), (max_batch_size*2, 1, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2, 80, 3000), (max_batch_size*2,), (max_batch_size*2, 80)]
+            input_names = ["x", "mask", "mu", "cond", "t", "spks"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
+    def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
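+        # Quantize the 16 kHz prompt audio into 25 Hz speech tokens with the
+        # s3tokenizer speech_tokenizer_v2 model loaded in __init__.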
+        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
+            prompt_speech_mels_list.append(log_mel)
+        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
+        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
+            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
+        )
+        for i in range(len(prompt_speech_tokens)):
+            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
+            prompt_speech_tokens_list.append(speech_tokens_i)
+        return prompt_speech_tokens_list
+    
+    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
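+        # Compute a 192-dim speaker embedding per prompt from mean-normalized 80-dim
+        # fbank features (campplus model, ONNX Runtime or TensorRT depending on setup).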
+        spk_emb_for_flow = []
+        for audio in prompt_audios_list:
+            assert len(audio.shape) == 1
+            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
+            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
+            spk_emb = self.forward_spk_embedding(spk_feat)
+
+            spk_emb_for_flow.append(spk_emb)
+        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
+        if self.dtype != torch.float32:
+            spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
+        return spk_emb_for_flow
+    
+    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
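+        # Resample each prompt to 24 kHz and extract the mel spectrogram used by the
+        # flow decoder as its acoustic prompt; returns padded mels and their lengths.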
+        prompt_mels_for_flow = []
+        prompt_mels_lens_for_flow = []
+        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
+            assert len(audio.shape) == 1
+            audio = audio.unsqueeze(0)
+            if sample_rate != 24000:
+                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
+            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
+            mel_len = mel.shape[0]
+            prompt_mels_for_flow.append(mel)
+            prompt_mels_lens_for_flow.append(mel_len)
+        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
+        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
+        return prompt_mels_for_flow, prompt_mels_lens_for_flow
+    
+    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
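+        # Concatenate prompt and generated speech tokens per utterance and run the
+        # flow-matching decoder (10 timesteps) to produce mel spectrograms conditioned
+        # on the prompt mels and the speaker embedding.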
+        batch_size = prompt_mels_for_flow.shape[0]
+        flow_inputs = []
+        flow_inputs_lens = []
+        for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
+            flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
+            flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
+
+        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
+        flow_inputs_lens = torch.tensor(flow_inputs_lens)
+
+        with torch.amp.autocast('cuda', dtype=torch.float16):  # autocast expects the device type as a string
+            generated_mels, generated_mels_lens = self.flow.inference(
+                flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
+                prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
+            )
+
+        return generated_mels, generated_mels_lens
+
+    def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
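+        # Vocode each utterance with the HiFT vocoder, skipping the prompt portion of
+        # the generated mels.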
+        batch_size = generated_mels.shape[0]
+        generated_wavs = []
+        for i in range(batch_size):
+            mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
+            wav, _ = self.hift(speech_feat=mel)
+            generated_wavs.append(wav)
+        return generated_wavs
+
+
+    @torch.inference_mode()
+    def forward(
+        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
+    ):
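+        # Offline (non-streaming) path: tokenize the prompt audio, run the flow decoder
+        # over the full token sequence, then vocode to 24 kHz waveforms.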
+        # all prompt audios must be 16 kHz
+        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
+
+        prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
+
+        generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+
+        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
+        
+        return generated_wavs
+
+    def prepare_prompt_audio(
+        self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
+    ):
+        # all prompt audios must be 16 kHz
+        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
+
+        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
+
+        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
+
+        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
+        
+        return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+
+
+    def get_prompt_audio_cache_for_streaming_tts(
+        self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
+    ):
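+        # Pre-compute the flow decoder caches from the prompt so that streaming chunks
+        # only need to process newly generated tokens; the returned cache tensors have
+        # batch dimension 1.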
+        assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
+        for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
+            prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
+        prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
+
+        cache = self.flow.setup_cache(
+            prompt_speech_tokens_tensor.to(self.device),
+            prompt_mels_for_flow.to(self.device),
+            spk_emb_for_flow.to(self.device),
+            n_timesteps=10
+        )
+
+        # cache dict's tensor batch dim is 1 for now
+        return cache
+
+
+    @torch.inference_mode()
+    def forward_streaming(
+        self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
+    ):
+
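+        # Streaming path: per-speaker prompt caches are built once and reused across
+        # requests, while per-request flow and HiFT caches carry state between chunks
+        # and are released when the last chunk arrives.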
+        if speaker_id not in self.speaker_cache:
+            assert prompt_audio is not None, "prompt_audio is required for new speaker"
+            assert prompt_audio_sample_rate == 16000
+
+            prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])
+
+            token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
+            prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
+            prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]
+
+            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
+            prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
+            
+            self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
+
+        if request_id not in self.streaming_flow_cache:
+            self.streaming_flow_cache[request_id] = self.speaker_cache[speaker_id]['cache_dict'].copy()
+            self.hift_cache_dict[request_id] = dict(
+                mel=torch.zeros(1, 80, 0, device='cuda'),
+                source=torch.zeros(1, 1, 0, device='cuda'),
+                speech=torch.zeros(1, 0, device='cuda'),
+            )
+
+        current_request_cache = self.streaming_flow_cache[request_id]
+        prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']
+        generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
+
+        chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
+            token=generated_speech_tokens,
+            spk=prompt_audio_dict['spk_emb_for_flow'].to(self.device),
+            cache=current_request_cache,
+            last_chunk=last_chunk,
+            n_timesteps=10,
+        )
+
+        self.streaming_flow_cache[request_id] = new_streaming_flow_cache
+
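+        # Bound the estimator attention cache: keep the prompt frames plus only the
+        # most recent 100 frames so the cache does not grow with the stream length.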
+        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
+            self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
+                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
+            ], dim=4)
+
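+        # Vocode with overlap: prepend the cached mel tail, reuse the cached excitation
+        # source, and cross-fade against the cached speech tail (Hamming window) to
+        # avoid clicks at chunk boundaries.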
+        hift_cache_mel = self.hift_cache_dict[request_id]['mel']
+        hift_cache_source = self.hift_cache_dict[request_id]['source']
+        hift_cache_speech = self.hift_cache_dict[request_id]['speech']
+        mel = torch.concat([hift_cache_mel, chunk_mel], dim=2)
+
+        speech, source = self.hift(mel, hift_cache_source)
+
+        # overlap speech smooth
+        if hift_cache_speech.shape[-1] > 0:
+            speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
+
+        # update vocoder cache
+        self.hift_cache_dict[request_id] = dict(
+            mel=mel[..., -self.mel_cache_len:].clone().detach(),
+            source=source[:, :, -self.source_cache_len:].clone().detach(),
+            speech=speech[:, -self.source_cache_len:].clone().detach(),
+        )
+        if not last_chunk:
+            speech = speech[:, :-self.source_cache_len]
+
+        if last_chunk:
+            assert request_id in self.streaming_flow_cache
+            self.streaming_flow_cache.pop(request_id)
+            self.hift_cache_dict.pop(request_id)
+        
+        return speech
+
+def collate_fn(batch):
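+    # Collate items from the evaluation dataset: pre-generated CosyVoice2 speech tokens
+    # plus the raw prompt audio array and its sampling rate.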
+    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
+    for i, item in enumerate(batch):
+        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
+        audio = torch.from_numpy(item['prompt_audio']['array']).float() 
+        prompt_audios_list.append(audio)
+        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
+        ids.append(item['id'])
+
+    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--enable-trt", action="store_true")
+    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--output-dir", type=str, default="generated_wavs")
+    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
+    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    args = get_args()
+    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
+    os.makedirs(args.output_dir, exist_ok=True)
+    dataset_name = "yuekai/seed_tts_cosy2"
+
+    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
+
+    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
+
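+    # Run args.warmup full passes over the dataset; earlier passes absorb TensorRT/CUDA
+    # warm-up cost, so the time printed for the last pass is the one to report.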
+    for epoch in range(args.warmup):
+        start_time = time.time()
+
+        for batch in data_loader:
+            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
+            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
+
+            for utt_id, wav in zip(ids, generated_wavs):
+                torchaudio.save(f"{args.output_dir}/{utt_id}.wav", wav.cpu(), 24000)
+
+        epoch_time = time.time() - start_time
+        print(f"Epoch {epoch} time taken: {epoch_time:.4f} seconds")