Browse Source

add streaming dit

yuekaiz 2 months ago
parent
commit
482464ea27

+ 10 - 4
runtime/triton_trtllm/client_grpc.py

@@ -209,7 +209,8 @@ def get_args():
         choices=[
             "f5_tts",
             "spark_tts",
-            "cosyvoice2"],
+            "cosyvoice2",
+            "cosyvoice2_dit"],
         help="triton model_repo module name to request",
     )
 
@@ -260,8 +261,8 @@ def get_args():
 
     parser.add_argument(
         "--use-spk2info-cache",
-        type=bool,
-        default=False,
+        type=str,
+        default="False",
         help="Use spk2info cache for reference audio.",
     )
 
@@ -490,6 +491,7 @@ async def send_streaming(
                     padding_duration=padding_duration,
                     use_spk2info_cache=use_spk2info_cache
                 )
+
                 request_id = str(uuid.uuid4())
                 user_data = UserData()
 
@@ -670,11 +672,15 @@ async def main():
             trust_remote_code=True,
         )
         manifest_item_list = []
+        tmp_audio_path="./asset_zero_shot_prompt.wav"
+        tmp_audio_text="希望你以后能够做的比我还好呦。"
         for i in range(len(dataset)):
             manifest_item_list.append(
                 {
                     "audio_filepath": dataset[i]["prompt_audio"],
                     "reference_text": dataset[i]["prompt_text"],
+                    # "audio_filepath": tmp_audio_path,
+                    # "reference_text": tmp_audio_text,
                     "target_audio_path": dataset[i]["id"],
                     "target_text": dataset[i]["target_text"],
                 }
@@ -686,7 +692,7 @@ async def main():
     manifest_item_list = split_data(manifest_item_list, num_tasks)
 
     os.makedirs(args.log_dir, exist_ok=True)
-
+    args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
     tasks = []
     start_time = time.time()
     for i in range(num_tasks):

+ 21 - 38
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py

@@ -227,12 +227,11 @@ class TritonPythonModel:
 
     def forward_token2wav(
             self,
+            index: int,
             target_speech_tokens: torch.Tensor,
             request_id: str,
-            prompt_speech_tokens: torch.Tensor = None,
-            prompt_speech_feat: torch.Tensor = None,
-            prompt_spk_embedding: torch.Tensor = None,
-            token_offset: int = None,
+            reference_wav: object,
+            reference_wav_len: object,
             finalize: bool = None) -> torch.Tensor:
         """Forward pass through the vocoder component.
 
@@ -246,29 +245,16 @@ class TritonPythonModel:
             Generated waveform tensor
         """
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
-
-        inputs_tensor = [target_speech_tokens_tensor]
-
-        if token_offset is not None:
-            assert finalize is not None
-            token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
-            finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
-            inputs_tensor.append(token_offset_tensor)
-            inputs_tensor.append(finalize_tensor)
-
-        if prompt_spk_embedding is not None:
-            assert prompt_speech_feat is not None
-            prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
-            prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
-            prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
-            inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
+        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
 
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
-            model_name='token2wav',
+            model_name='token2wav_dit',
             requested_output_names=['waveform'],
             inputs=inputs_tensor,
             request_id=request_id,
+            parameters={"priority": index+1},
         )
 
         inference_response = inference_request.exec()
@@ -346,8 +332,15 @@ class TritonPythonModel:
 
                 reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
                 reference_text = reference_text[0][0].decode('utf-8')
-                prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+                # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+
+                # reference_text = self.default_spk_info["prompt_text"]
+                # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+                # prompt_speech_feat = None
+                # prompt_spk_embedding = None
+
             else:
+                assert False, "wav is None"
                 # using pre-cached reference text
                 reference_text = self.default_spk_info["prompt_text"]
                 prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
@@ -391,12 +384,12 @@ class TritonPythonModel:
                         break
 
                     if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
-                        this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
+                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
 
                         sub_tts_speech = self.forward_token2wav(
-                            this_tts_speech_token, request_id, prompt_speech_tokens,
-                            prompt_speech_feat, prompt_spk_embedding, token_offset, False
+                            chunk_index,
+                            this_tts_speech_token, request_id, wav, wav_len, False
                         )
 
                         audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
@@ -429,8 +422,8 @@ class TritonPythonModel:
                     else:
                         time.sleep(0.02)
 
-                this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                sub_tts_speech = self.forward_token2wav(this_tts_speech_token, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
+                this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
+                sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
@@ -439,17 +432,7 @@ class TritonPythonModel:
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                 self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
-                generated_ids = next(generated_ids_iter)
-                generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
-                if generated_ids is None or len(generated_ids) == 0:
-                    raise pb_utils.TritonModelException("Generated IDs is None or empty")
-
-                audio = self.forward_token2wav(generated_ids, request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
-
-                # Prepare response
-                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
-                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
-                responses.append(inference_response)
+                raise NotImplementedError("Decoupled mode is not supported")
 
         if not self.decoupled:
             return responses

+ 438 - 0
runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py

@@ -0,0 +1,438 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import math
+import os
+import re
+import time
+from typing import Dict, List, Tuple, Optional, Union
+import asyncio
+import httpx
+
+import numpy as np
+import torch
+from torch.utils.dlpack import from_dlpack, to_dlpack
+import triton_python_backend_utils as pb_utils
+from transformers import AutoTokenizer
+
+import torchaudio
+
+
+from matcha.utils.audio import mel_spectrogram
+
+ORIGINAL_VOCAB_SIZE = 151663
+torch.set_num_threads(1)
+
+
+def parse_speech_token_string(response_text: str) -> List[int]:
+    """
+    Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
+    """
+    speech_tokens = response_text.strip().split('><')
+    if len(speech_tokens) > 1:
+        # Add back the missing '<' and '>' for proper parsing
+        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
+        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
+
+    speech_ids = []
+    for token_str in speech_tokens:
+        match = re.match(r'<\|s_(\d+)\|>', token_str)
+        if match:
+            speech_ids.append(int(match.group(1)))
+    return speech_ids
+
+
+class TritonPythonModel:
+    """Triton Python model for Spark TTS.
+
+    This model orchestrates the end-to-end TTS pipeline by coordinating
+    between audio tokenizer, LLM, and vocoder components.
+    """
+
+    def initialize(self, args):
+        """Initialize the model.
+
+        Args:
+            args: Dictionary containing model configuration
+        """
+        self.logger = pb_utils.Logger
+        # Parse model parameters
+        self.model_config = json.loads(args['model_config'])
+        parameters = self.model_config['parameters']
+        model_params = {k: v["string_value"] for k, v in parameters.items()}
+        self.logger.log_info(f"model_params:{model_params}")
+        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
+        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
+
+        # Initialize tokenizer
+        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
+        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
+        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
+        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
+
+        self.device = torch.device("cuda")
+        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
+
+        self.token_frame_rate = 25
+        self.flow_pre_lookahead_len = 3
+        self.token_hop_len = 15
+
+        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
+        if not os.path.exists(spk_info_path):
+            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
+        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
+        # self.default_spk_info = spk_info["001"]
+
+    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
+        """Converts a tensor or list of speech token IDs to a string representation."""
+        if isinstance(speech_tokens, torch.Tensor):
+            # Ensure tensor is on CPU and flattened
+            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
+
+        speech_id_str = ""
+        for token_id in speech_tokens:
+            # Convert token ID back to the speech number N
+            token_num = token_id - ORIGINAL_VOCAB_SIZE
+            speech_id_str += f"<|s_{token_num}|>"
+        return speech_id_str
+
+    async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
+        """
+        Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
+        """
+        full_text = f"{reference_text}{target_text}"
+        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
+
+        chat = [
+            {"role": "user", "content": full_text},
+            {"role": "assistant", "content": prompt_speech_tokens_str}
+        ]
+        print(chat)
+
+        payload = {
+            "model": "trt_engines_bfloat16",
+            "messages": chat,
+            "max_tokens": 750,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "stop": ["<|eos1|>", "<|eos|>"],
+            "stream": True,
+        }
+
+        api_base = "http://localhost:8000/v1/chat/completions"
+
+        buffer = ""
+        async with httpx.AsyncClient() as client:
+            async with client.stream("POST", api_base, json=payload, timeout=None) as response:
+                response.raise_for_status()
+                async for line in response.aiter_lines():
+                    if line.startswith("data: "):
+                        line_data = line[len("data: "):].strip()
+                        if line_data == "[DONE]":
+                            break
+                        try:
+                            json_data = json.loads(line_data)
+                            content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
+                            if content:
+                                buffer += content
+                                while True:
+                                    match = re.search(r"<\|s_(\d+)\|>", buffer)
+                                    if not match:
+                                        break
+
+                                    token_num = int(match.group(1))
+                                    final_id = token_num + ORIGINAL_VOCAB_SIZE
+                                    yield final_id
+                                    buffer = buffer[match.end():]
+                        except json.JSONDecodeError:
+                            self.logger.log_info(f"Skipping non-JSON line: {line_data}")
+                            continue
+
+        # Process any remaining complete tokens in the buffer after the stream ends
+        while True:
+            match = re.search(r"<\|s_(\d+)\|>", buffer)
+            if not match:
+                break
+            token_num = int(match.group(1))
+            final_id = token_num + ORIGINAL_VOCAB_SIZE
+            yield final_id
+            buffer = buffer[match.end():]
+
+
+    def forward_audio_tokenizer(self, wav, wav_len):
+        """Forward pass through the audio tokenizer component.
+
+        Args:
+            wav: Input waveform tensor
+            wav_len: Waveform length tensor
+
+        Returns:
+            Tuple of global and semantic tokens
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='audio_tokenizer',
+            requested_output_names=['prompt_speech_tokens'],
+            inputs=[wav, wav_len]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
+        prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
+
+        return prompt_speech_tokens
+
+    def forward_speaker_embedding(self, wav):
+        """Forward pass through the speaker embedding component.
+
+        Args:
+            wav: Input waveform tensor
+
+        Returns:
+            Prompt speaker embedding tensor
+        """
+        inference_request = pb_utils.InferenceRequest(
+            model_name='speaker_embedding',
+            requested_output_names=['prompt_spk_embedding'],
+            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output tensors
+        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
+        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
+
+        return prompt_spk_embedding
+
+    def forward_token2wav(
+            self,
+            index: int,
+            target_speech_tokens: torch.Tensor,
+            request_id: str,
+            reference_wav: object,
+            reference_wav_len: object,
+            finalize: bool = None) -> torch.Tensor:
+        """Forward pass through the vocoder component.
+
+        Args:
+            prompt_speech_tokens: Prompt speech tokens tensor
+            prompt_speech_feat: Prompt speech feat tensor
+            prompt_spk_embedding: Prompt spk embedding tensor
+            target_speech_tokens: Target speech tokens tensor
+
+        Returns:
+            Generated waveform tensor
+        """
+        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
+        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
+        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
+
+        # Create and execute inference request
+        inference_request = pb_utils.InferenceRequest(
+            model_name='token2wav_dit',
+            requested_output_names=['waveform'],
+            inputs=inputs_tensor,
+            request_id=request_id,
+            parameters={"priority": index+1},
+        )
+
+        inference_response = inference_request.exec()
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Extract and convert output waveform
+        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
+        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
+
+        return waveform
+
+    def _extract_speech_feat(self, speech):
+        speech_feat = mel_spectrogram(
+            speech,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000).squeeze(
+            dim=0).transpose(
+            0,
+            1).to(
+                self.device)
+        speech_feat = speech_feat.unsqueeze(dim=0)
+        return speech_feat
+
+    async def _process_request(self, request):
+        request_id = request.request_id()
+        # Extract input tensors
+        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+
+        # Process reference audio through audio tokenizer
+        if wav is not None:
+            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+            wav_tensor = wav.as_numpy()
+            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+            speech_feat = self._extract_speech_feat(prompt_speech_resample)
+            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+            reference_text = reference_text[0][0].decode('utf-8')
+            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+
+            # reference_text = self.default_spk_info["prompt_text"]
+            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+            # prompt_speech_feat = None
+            # prompt_spk_embedding = None
+
+        else:
+            # using pre-cached reference text
+            assert False, "using pre-cached reference text is not supported"
+            reference_text = self.default_spk_info["prompt_text"]
+            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+            prompt_speech_feat = None
+            prompt_spk_embedding = None
+
+        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+        target_text = target_text[0][0].decode('utf-8')
+
+        if self.decoupled:
+            response_sender = request.get_response_sender()
+
+            semantic_token_ids_arr = []
+            token_offset, chunk_index = 0, 0
+            start_time = time.time()
+            this_token_hop_len = self.token_hop_len
+
+            async for generated_ids in self.forward_llm_async(
+                target_text=target_text,
+                reference_text=reference_text,
+                prompt_speech_tokens=prompt_speech_tokens,
+            ):
+                if not generated_ids:
+                    break
+                semantic_token_ids_arr.append(generated_ids)
+                
+                while True:
+                    pending_num = len(semantic_token_ids_arr) - token_offset
+                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
+                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
+                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
+
+                        sub_tts_speech = self.forward_token2wav(
+                            chunk_index,
+                            this_tts_speech_token, request_id, wav, wav_len, False
+                        )
+
+                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+                        response_sender.send(inference_response)
+
+                        token_offset += this_token_hop_len
+                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
+
+                        if self.dynamic_chunk_strategy == "exponential":
+                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "time_based":
+                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
+                            cost_time = time.time() - start_time
+                            duration = token_offset / self.token_frame_rate
+                            if chunk_index > 0 and cost_time > 0:
+                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
+                                if avg_chunk_processing_time > 0:
+                                    multiples = (duration - cost_time) / avg_chunk_processing_time
+                                    self.logger.log_info(f"multiples: {multiples}")
+                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
+                                    if multiples > 4:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
+                                    elif multiples > 2:
+                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
+                                    else:
+                                        this_token_hop_len = self.token_hop_len
+                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
+                        chunk_index += 1
+                    else:
+                        break
+            
+            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
+            sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
+            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+            response_sender.send(inference_response)
+
+            ## debug
+            ## save semantic_token_ids_arr and reference_text, target_text to a single json file
+            # save into a torch .pt
+            # for i, item in enumerate(semantic_token_ids_arr):
+            #     semantic_token_ids_arr[i] = item - ORIGINAL_VOCAB_SIZE
+            # import json
+            # data = {
+            #     "semantic_token_ids_arr": semantic_token_ids_arr,
+            #     "reference_text": reference_text,
+            #     "target_text": target_text
+            # }
+            # with open(f"semantic_token_ids_arr_debug_{request_id}.pt", "wb") as f:
+            #     torch.save(data, f)
+            # with open(f"semantic_token_ids_arr_debug_{request_id}.json", "w") as f:
+            #     json.dump(data, f)
+            
+            # ##
+
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+            self.logger.log_info("send tritonserver_response_complete_final to end")
+        else:
+            raise NotImplementedError("Decoupled mode is not supported")
+
+    async def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated audio
+        """
+        tasks = [
+            asyncio.create_task(self._process_request(request))
+            for request in requests
+        ]
+        await asyncio.gather(*tasks)
+        return None

+ 1 - 1
runtime/triton_trtllm/model_repo/cosyvoice2_dit/config.pbtxt

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-name: "cosyvoice2"
+name: "cosyvoice2_dit"
 backend: "python"
 max_batch_size: ${triton_max_batch_size}
 dynamic_batching {

+ 36 - 173
runtime/triton_trtllm/model_repo/token2wav_dit/1/model.py

@@ -42,6 +42,8 @@ from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vl
 from cosyvoice.utils.common import TrtContextWrapper
 from collections import defaultdict
 import numpy as np
+from .token2wav_dit import CosyVoice2_Token2Wav
+import hashlib
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
@@ -49,117 +51,19 @@ logger = logging.getLogger(__name__)
 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
 
+def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
+    """
+    Generates a unique ID for a torch.Tensor.
+    Tensors with the same elements and properties will have the same ID.
+    """
+    # Convert tensor to a byte string
+    tensor_bytes = tensor.numpy().tobytes()
 
-class CosyVoice2:
-
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
-
-        self.model_dir = model_dir
-        self.fp16 = fp16
-
-        hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
-        if not os.path.exists(hyper_yaml_path):
-            raise ValueError('{} not found!'.format(hyper_yaml_path))
-        with open(hyper_yaml_path, 'r') as f:
-            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
-        self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
-        self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
-        if load_jit:
-            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
-        if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
-                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
-                                trt_concurrent,
-                                self.fp16)
-
-
-class CosyVoice2Model:
-
-    def __init__(self,
-                 flow: torch.nn.Module,
-                 hift: torch.nn.Module,
-                 fp16: bool = False,
-                 device: str = 'cuda'):
-        self.device = device
-        self.flow = flow
-        self.hift = hift
-        self.fp16 = fp16
-        if self.fp16 is True:
-            self.flow.half()
-
-        # streaming tts config
-        self.token_hop_len = 25
-        self.mel_cache_len = 8
-        self.source_cache_len = int(self.mel_cache_len * 480)
-        self.speech_window = np.hamming(2 * self.source_cache_len)
-        self.hift_cache_dict = defaultdict(lambda: None)
-
-    def load_jit(self, flow_encoder_model):
-        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
-        self.flow.encoder = flow_encoder
-
-    def load(self, flow_model, hift_model):
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
-        self.flow.to(self.device).eval()
-        # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
-        self.hift.load_state_dict(hift_state_dict, strict=True)
-        self.hift.to(self.device).eval()
-
-    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
-        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
-        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
-            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
-        del self.flow.decoder.estimator
-        import tensorrt as trt
-        with open(flow_decoder_estimator_model, 'rb') as f:
-            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
-
-    def get_trt_kwargs(self):
-        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
-        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
-        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
-        input_names = ["x", "mask", "mu", "cond"]
-        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
-
-    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
-        with torch.cuda.amp.autocast(self.fp16):
-            tts_mel, _ = self.flow.inference(token=token.to(self.device),
-                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                             prompt_token=prompt_token.to(self.device),
-                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                             prompt_feat=prompt_feat.to(self.device),
-                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                             embedding=embedding.to(self.device),
-                                             streaming=stream,
-                                             finalize=finalize)
-        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
-        # append hift cache
-        if self.hift_cache_dict[uuid] is not None:
-            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
-            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
-        else:
-            hift_cache_source = torch.zeros(1, 1, 0)
-        # keep overlap mel and hift cache
-        if finalize is False:
-            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
-            if self.hift_cache_dict[uuid] is not None:
-                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
-            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
-                                          'source': tts_source[:, :, -self.source_cache_len:],
-                                          'speech': tts_speech[:, -self.source_cache_len:]}
-            tts_speech = tts_speech[:, :-self.source_cache_len]
-        else:
-            if speed != 1.0:
-                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
-                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
-            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
-            if self.hift_cache_dict[uuid] is not None:
-                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
-        return tts_speech
-
+    # Create a SHA-256 hash of the byte string
+    hasher = hashlib.sha256()
+    hasher.update(tensor_bytes)
+    
+    return hasher.hexdigest()
 
 class TritonPythonModel:
     """Triton Python model for vocoder.
@@ -183,16 +87,10 @@ class TritonPythonModel:
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
 
-        self.token2wav_model = CosyVoice2(
-            model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
+        # FIXME: device id settings
+        self.token2wav_model = CosyVoice2_Token2Wav(
+            model_dir, enable_trt=True, streaming=True
         )
-
-        spk_info_path = os.path.join(model_dir, "spk2info.pt")
-        if not os.path.exists(spk_info_path):
-            raise ValueError(f"spk2info.pt not found in {model_dir}")
-        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
-        self.default_spk_info = spk_info["001"]
-
         logger.info("Token2Wav initialized successfully")
 
     def execute(self, requests):
@@ -208,66 +106,31 @@ class TritonPythonModel:
         # Process each request in batch
         for request in requests:
             target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
-            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
-
-            prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
-            if prompt_speech_tokens_tensor is not None:
-                prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
-                prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
-                prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
-                prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
-                prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
-                prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
-                prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
-            else:
-                prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
-                prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
-                prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
-
+            target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)#.to(self.device)
             # shift the speech tokens according to the original vocab size
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
+            target_speech_tokens = target_speech_tokens.squeeze().tolist()
 
             # We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
-            token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
-            if token_offset is not None:
-                token_offset = token_offset.as_numpy().item()
-                finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
-                if not finalize:
-                    stream = True
-                else:
-                    stream = False
-                request_id = request.request_id()
-                audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
-                                                                 prompt_token=prompt_speech_tokens,
-                                                                 prompt_feat=prompt_speech_feat,
-                                                                 embedding=prompt_spk_embedding,
-                                                                 token_offset=token_offset,
-                                                                 uuid=request_id,
-                                                                 stream=stream,
-                                                                 finalize=finalize)
-                if finalize:
-                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
+           
+            finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+                
+            request_id = request.request_id()
+               
+
+            wav_array = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav").as_numpy()
+            wav_len = pb_utils.get_input_tensor_by_name(
+                request, "reference_wav_len").as_numpy().item()
+
+            wav_array = torch.from_numpy(wav_array)
+            # Prepare inputs
+            wav = wav_array[:, :wav_len].squeeze(0)
 
-            else:
-                tts_mel, _ = self.token2wav_model.model.flow.inference(
-                    token=target_speech_tokens,
-                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
-                        self.device
-                    ),
-                    prompt_token=prompt_speech_tokens,
-                    prompt_token_len=torch.tensor(
-                        [prompt_speech_tokens.shape[1]], dtype=torch.int32
-                    ).to(self.device),
-                    prompt_feat=prompt_speech_feat,
-                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
-                    embedding=prompt_spk_embedding,
-                    streaming=False,
-                    finalize=True,
-                )
+            spk_id = get_spk_id_from_prompt_audio(wav)
+            # wav = wav.to(self.device)
 
-                audio_hat, _ = self.token2wav_model.model.hift.inference(
-                    speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
-                )
+            audio_hat = self.token2wav_model.forward_streaming(target_speech_tokens, finalize, request_id=request_id, speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000)
 
             generated_wave = audio_hat.squeeze(0).cpu().numpy()
 

+ 48 - 10
runtime/triton_trtllm/token2wav_dit.py → runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py

@@ -362,17 +362,17 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             spk_emb_for_flow.to(self.device),
             n_timesteps=10
         )
+        new_cache = {k: v.clone() for k, v in cache.items()}
         # Hack: this is a hack to avoid in-place changes to the cache['estimator_att_cache'] and cache['estimator_cnn_cache']
-        cache['estimator_att_cache'] = cache['estimator_att_cache'].clone()
-        cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone()
-        return cache
+        return new_cache
 
 
     @torch.inference_mode()
     def forward_streaming(
         self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
-    ):
+    ):  
         if speaker_id not in self.speaker_cache:
+        # if 1:
             assert prompt_audio is not None, "prompt_audio is required for new speaker"
             assert prompt_audio_sample_rate == 16000
 
@@ -382,10 +382,21 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
             prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]
 
-            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
             prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
+
+        # if speaker_id not in self.speaker_cache:
+        # if 1:
             
+            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
             self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}
+            print(f"speaker_id {speaker_id} added to cache")
+
+            # get a clone of cache dict ['estimator_att_cache'] and later check if it would be change 
+        att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache'].clone()
+        cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache'].clone()
+        conformer_cnn_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache'].clone()
+        conformer_att_cache_clone = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache'].clone()
+    
 
         if request_id not in self.streaming_flow_cache:
             self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
@@ -409,6 +420,33 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             n_timesteps=10,
         )
 
+        # get the original att_cache
+        original_att_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_att_cache']
+        original_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['estimator_cnn_cache']
+        original_conformer_cnn_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_cnn_cache']
+        original_conformer_att_cache = self.speaker_cache[speaker_id]['cache_dict']['conformer_att_cache']
+        if not torch.allclose(original_att_cache, att_cache_clone):
+            print("att_cache changed")
+            # print the last 10 elements of original_att_cache and att_cache_clone
+            print(original_att_cache[:, :, :, -10:])
+            print(att_cache_clone[:, :, :, -10:])
+            breakpoint()
+        if not torch.allclose(original_cnn_cache, cnn_cache_clone):
+            print("cnn_cache changed")
+            print(original_cnn_cache[..., -10:])
+            print(cnn_cache_clone[..., -10:])
+            breakpoint()
+        if not torch.allclose(original_conformer_cnn_cache, conformer_cnn_cache_clone):
+            print("conformer_cnn_cache changed")
+            print(original_conformer_cnn_cache[..., -10:])
+            print(conformer_cnn_cache_clone[..., -10:])
+            breakpoint()
+        if not torch.allclose(original_conformer_att_cache, conformer_att_cache_clone):
+            print("conformer_att_cache changed")
+            print(original_conformer_att_cache[..., -10:])
+            print(conformer_att_cache_clone[..., -10:])
+            breakpoint()
+
         self.streaming_flow_cache[request_id] = new_streaming_flow_cache
 
 
@@ -420,10 +458,10 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
 
 
-        hift_cache_mel = self.hift_cache_dict[request_id]['mel']
-        hift_cache_source = self.hift_cache_dict[request_id]['source']
-        hift_cache_speech = self.hift_cache_dict[request_id]['speech']
-        mel = torch.concat([hift_cache_mel, chunk_mel], dim=2)
+        hift_cache_mel = self.hift_cache_dict[request_id]['mel'].clone()
+        hift_cache_source = self.hift_cache_dict[request_id]['source'].clone()
+        hift_cache_speech = self.hift_cache_dict[request_id]['speech'].clone()
+        mel = torch.concat([hift_cache_mel, chunk_mel], dim=2).clone()
 
         speech, source = self.hift(mel, hift_cache_source)
 
@@ -444,7 +482,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
             assert request_id in self.streaming_flow_cache
             self.streaming_flow_cache.pop(request_id)
             self.hift_cache_dict.pop(request_id)
-        
+        # breakpoint()
         return speech
 
 def collate_fn(batch):

+ 7 - 20
runtime/triton_trtllm/model_repo/token2wav_dit/config.pbtxt

@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-name: "token2wav"
+name: "token2wav_dit"
 backend: "python"
 max_batch_size: ${triton_max_batch_size}
 dynamic_batching {
     max_queue_delay_microseconds: ${max_queue_delay_microseconds}
+    priority_levels: 10
+    default_priority_level: 10
 }
 parameters [
   {
@@ -32,29 +34,14 @@ input [
     dims: [-1]
   },
   {
-    name: "prompt_speech_tokens"
-    data_type: TYPE_INT32
-    dims: [-1]
-    optional: true
-  },
-  {
-    name: "prompt_speech_feat"
-    data_type: TYPE_FP16
-    dims: [-1, 80]
-    optional: true
-  },
-  {
-    name: "prompt_spk_embedding"
-    data_type: TYPE_FP16
+    name: "reference_wav"
+    data_type: TYPE_FP32
     dims: [-1]
-    optional: true
   },
   {
-    name: "token_offset"
+    name: "reference_wav_len"
     data_type: TYPE_INT32
-    dims: [ 1 ]
-    reshape: { shape: [ ] }
-    optional: true
+    dims: [1]
   },
   {
     name: "finalize"

+ 94 - 5
runtime/triton_trtllm/offline_inference.py

@@ -43,6 +43,9 @@ import soundfile as sf
 import s3tokenizer
 from functools import partial
 import time
+import requests
+import asyncio
+import httpx
 
 from token2wav import CosyVoice2_Token2Wav
 
@@ -53,6 +56,32 @@ except RuntimeError:
     pass
 
 
+async def send_request_async(client, url, payload):
+    response = await client.post(url, json=payload, timeout=None)
+    response.raise_for_status()
+    response_json = response.json()
+    return response_json['choices'][0]['message']['content']
+
+
+async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
+    async with httpx.AsyncClient() as client:
+        tasks = []
+        for chat in chats:
+            payload = {
+                "model": model_name,
+                "messages": chat,
+                "max_tokens": 2048,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": 1.1,
+                "stop": ["<|eos1|>", "<|eos|>"],
+                "stream": False,
+            }
+            tasks.append(send_request_async(client, api_base, payload))
+        return await asyncio.gather(*tasks)
+
+
 def extract_speech_ids(speech_tokens_str):
     """Extract speech IDs from token strings like <|s_23456|>"""
     speech_ids = []
@@ -149,7 +178,7 @@ def get_args():
         "--backend",
         type=str,
         default="hf",
-        choices=["hf", "trtllm", "vllm"],
+        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
         help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
     )
     parser.add_argument(
@@ -164,6 +193,18 @@ def get_args():
         default=0.6,
         help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
     )
+    parser.add_argument(
+        "--openai-api-base",
+        type=str,
+        default="http://localhost:8000/v1/chat/completions",
+        help="OpenAI API base URL (for trtllm-serve backend)",
+    )
+    parser.add_argument(
+        "--openai-model-name",
+        type=str,
+        default="trt_engines_bfloat16",
+        help="Model name to use with OpenAI API (for trtllm-serve backend)",
+    )
     args = parser.parse_args()
     return args
 
@@ -180,6 +221,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     prompt_text_after_apply_template_list = []
     mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
+    chat_list = []
     for _, item in enumerate(batch):
         audio_processing_start_time = time.time()
         prompt_text, target_text = (
@@ -237,6 +279,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             {"role": "user", "content": full_text_list[i]},
             {"role": "assistant", "content": prompt_audio_cosy2_id_str}
         ]
+        chat_list.append(chat)
 
         assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
 
@@ -265,6 +308,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         "audio_processing_time": total_audio_processing_time,
         "speech_tokenization_time": total_speech_tokenization_time,
         "text_tokenization_time": total_text_tokenization_time,
+        "chat_list": chat_list
     }
 
 
@@ -318,6 +362,9 @@ def main(args):
     elif args.backend == "vllm":
         model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
         runner = None
+    elif args.backend == "trtllm-serve":
+        model = None
+        runner = None
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
 
@@ -452,6 +499,35 @@ def main(args):
                     print(outputs)
                     for j, output in enumerate(outputs):
                         outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
+                elif args.backend == "trtllm-serve":
+                    if args.batch_size > 1:
+                        outputs = asyncio.run(send_batch_requests_async(
+                            args.openai_api_base,
+                            args.openai_model_name,
+                            batch["chat_list"],
+                            args.temperature,
+                            args.top_p,
+                            args.top_k,
+                        ))
+                    else:
+                        outputs = []
+                        for i, chat in enumerate(batch["chat_list"]):
+                            payload = {
+                                "model": args.openai_model_name,
+                                "messages": chat,
+                                "max_tokens": 2048,
+                                "temperature": args.temperature,
+                                "top_p": args.top_p,
+                                "top_k": args.top_k,
+                                "repetition_penalty": 1.1,
+                                "stop": ["<|eos1|>", "<|eos|>"],
+                                "stream": False,
+                            }
+                            response = requests.post(args.openai_api_base, json=payload)
+                            response.raise_for_status()
+                            response_json = response.json()
+                            generated_content = response_json['choices'][0]['message']['content']
+                            outputs.append(generated_content)
 
                 llm_end_time = time.time()
                 total_llm_time += (llm_end_time - llm_start_time)
@@ -459,10 +535,21 @@ def main(args):
                 items_for_token_2wav = []
                 for i in range(len(batch["ids"])):
                     llm_post_processing_start_time = time.time()
-                    input_length = len(batch["input_ids"][i])
-                    generated_ids = outputs[i][input_length:]
-                    speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-                    speech_ids = extract_speech_ids(speech_tokens_str)
+                    if args.backend == "trtllm-serve":
+                        speech_tokens_str = outputs[i].strip().split('><')
+                        if len(speech_tokens_str) > 1:
+                            speech_tokens_str = [
+                                t if t.startswith('<') else '<' + t for t in speech_tokens_str
+                            ]
+                            speech_tokens_str = [
+                                t if t.endswith('>') else t + '>' for t in speech_tokens_str
+                            ]
+                        speech_ids = extract_speech_ids(speech_tokens_str)
+                    else:
+                        input_length = len(batch["input_ids"][i])
+                        generated_ids = outputs[i][input_length:]
+                        speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+                        speech_ids = extract_speech_ids(speech_tokens_str)
                     print(i, speech_ids)
                     if len(speech_ids) == 0:
                         print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
@@ -558,6 +645,8 @@ if __name__ == "__main__":
         from tensorrt_llm.runtime import ModelRunnerCpp
     elif args.backend == "hf":
         from transformers import AutoModelForCausalLM
+    elif args.backend == "trtllm-serve":
+        pass
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
     main(args)

+ 80 - 18
runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
-export CUDA_VISIBLE_DEVICES=0
+export CUDA_VISIBLE_DEVICES=1
 cosyvoice_path=/workspace/CosyVoice
 cosyvoice_path=/workspace_yuekai/tts/CosyVoice
 stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
@@ -16,7 +16,7 @@ trt_dtype=bfloat16
 trt_weights_dir=./trt_weights_${trt_dtype}
 trt_engines_dir=./trt_engines_${trt_dtype}
 
-model_repo=./model_repo_cosyvoice2
+model_repo=./model_repo_cosyvoice2_dit
 
 use_spk2info_cache=False
 
@@ -58,40 +58,78 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
                     --engine_dir=$trt_engines_dir  || exit 1
 fi
 
+
+# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+#     echo "Creating model repository"
+#     rm -rf $model_repo
+#     mkdir -p $model_repo
+#     cosyvoice2_dir="cosyvoice2_dit"
+#     token2wav_dir="token2wav_dit"
+
+#     cp -r ./model_repo/${cosyvoice2_dir} $model_repo
+#     cp -r ./model_repo/tensorrt_llm $model_repo
+#     cp -r ./model_repo/${token2wav_dir} $model_repo
+#     #if [ $use_spk2info_cache == "False" ]; then
+#         cp -r ./model_repo/audio_tokenizer $model_repo
+#         cp -r ./model_repo/speaker_embedding $model_repo
+#     #fi
+
+#     ENGINE_PATH=$trt_engines_dir
+#     MAX_QUEUE_DELAY_MICROSECONDS=0
+#     MODEL_DIR=$model_scope_model_local_dir
+#     LLM_TOKENIZER_DIR=$huggingface_model_local_dir
+#     BLS_INSTANCE_NUM=1
+#     TRITON_MAX_BATCH_SIZE=16
+#     DECOUPLED_MODE=True # True for streaming, False for offline
+#     STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
+
+#     python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+#     python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+#     python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
+#     #if [ $use_spk2info_cache == "False" ]; then
+#         python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+#         python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+#     #fi
+# fi
+
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-    echo "Creating model repository"
+    echo "Creating model repository async mode"
     rm -rf $model_repo
     mkdir -p $model_repo
-    cosyvoice2_dir="cosyvoice2"
+    cosyvoice2_dir="cosyvoice2_dit"
+    token2wav_dir="token2wav_dit"
 
     cp -r ./model_repo/${cosyvoice2_dir} $model_repo
     cp -r ./model_repo/tensorrt_llm $model_repo
-    cp -r ./model_repo/token2wav $model_repo
-    if [ $use_spk2info_cache == "False" ]; then
+    cp -r ./model_repo/${token2wav_dir} $model_repo
+    #if [ $use_spk2info_cache == "False" ]; then
         cp -r ./model_repo/audio_tokenizer $model_repo
         cp -r ./model_repo/speaker_embedding $model_repo
-    fi
+    #fi
 
     ENGINE_PATH=$trt_engines_dir
     MAX_QUEUE_DELAY_MICROSECONDS=0
     MODEL_DIR=$model_scope_model_local_dir
     LLM_TOKENIZER_DIR=$huggingface_model_local_dir
     BLS_INSTANCE_NUM=4
-    TRITON_MAX_BATCH_SIZE=16
+    TRITON_MAX_BATCH_SIZE=32
     DECOUPLED_MODE=True # True for streaming, False for offline
+    STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
 
-    python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
     python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
     python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
-    if [ $use_spk2info_cache == "False" ]; then
+    #if [ $use_spk2info_cache == "False" ]; then
         python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
         python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-    fi
+    #fi
+    rm -rf $model_repo/tensorrt_llm
+    # mv $model_repo/cosyvoice2_dit/1 $model_repo/cosyvoice2_dit/4
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
    echo "Starting Triton server"
-   tritonserver --model-repository $model_repo
+   tritonserver --model-repository $model_repo --http-port 18000
 fi
 
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
@@ -112,26 +150,26 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
 
     python3 client_grpc.py \
         --server-addr localhost \
-        --model-name cosyvoice2 \
+        --model-name cosyvoice2_dit \
         --num-tasks $num_task \
         --mode $mode \
-        --use-spk2info-cache $use_spk2info_cache \
         --huggingface-dataset yuekai/seed_tts_cosy2 \
-        --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_spk_cache_${use_spk2info_cache}
+        --log-dir ./log_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_no_att_cnn_cache_new
 fi
 
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   echo "stage 6: Offline inference benchmark"
   n_gpus=1
   datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
-  backend=trtllm # hf, trtllm, vllm
+  backend=trtllm-serve # hf, trtllm, vllm
 
   batch_sizes=(16 8 4 2 1)
+  batch_sizes=(16 8 4 2)
   token2wav_batch_size=1
   for batch_size in ${batch_sizes[@]}; do
     for dataset in ${datasets[@]}; do
     output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
-    CUDA_VISIBLE_DEVICES=0 \
+    CUDA_VISIBLE_DEVICES=1 \
         python3 offline_inference.py \
             --output-dir $output_dir \
             --llm-model-name-or-path $huggingface_model_local_dir \
@@ -147,7 +185,31 @@ fi
 
 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
 
-   python3 benchmark_streaming_token2wav.py --enable-trt
+   python3 streaming_inference.py
 
 
+fi
+
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+    mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16 
+    
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+    #! /usr/bin/env bash
+    curl http://localhost:8000/v1/chat/completions \
+        -H "Content-Type: application/json" \
+        -d '{
+            "model": "trt_engines_bfloat16",
+            "messages":[{"role": "user", "content": "Where is New York?"},
+                        {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
+            "max_tokens": 512,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "top_k": 50,
+            "stop": ["<|eos1|>"],
+            "repetition_penalty": 1.2,
+            "stream": false
+        }'
 fi

+ 115 - 0
runtime/triton_trtllm/streaming_inference.py

@@ -0,0 +1,115 @@
+import torch
+import os
+import argparse
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+import numpy as np
+import torchaudio
+import time
+from token2wav_dit import CosyVoice2_Token2Wav
+import soundfile as sf
+
+def collate_fn(batch):
+    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
+    prompt_speech_tokens_list, prompt_text_list = [], []
+    for i, item in enumerate(batch):
+        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
+        audio = torch.from_numpy(item['prompt_audio']['array']).float() 
+        prompt_audios_list.append(audio)
+        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
+        ids.append(item['id'])
+        prompt_speech_tokens_list.append(item['prompt_audio_cosy2_tokens'])
+        prompt_text_list.append(item['prompt_text'])
+
+    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--enable-trt", action="store_true")
+    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--output-dir", type=str, default="generated_wavs")
+    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
+    parser.add_argument("--dataset-name", type=str, default="yuekai/seed_tts_cosy2")
+    return parser.parse_args()
+
+
+def fake_generated_id_iter(generated_speech_tokens_list):
+    for i in range(len(generated_speech_tokens_list)):
+        yield generated_speech_tokens_list[i]
+
+
+
+if __name__ == "__main__":
+    args = get_args()
+    
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    dataset_name = args.dataset_name
+    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
+    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
+
+    token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
+    
+    flow_pre_lookahead_len = 3
+    CHUNK_SIZE = 25
+    OVERLAP_SIZE = 0
+
+    warmup_times = 3
+    for _ in range(warmup_times):
+        start_time = time.time()
+        for batch in data_loader:
+            tts_speech_list = []
+            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
+
+            id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
+            # if id != "unseen3_text5":
+            #     continue
+            # else:
+            #     a = torch.load("semantic_token_ids_arr_debug_871e2b90-42a7-4829-957c-b45e6a96fdb2.pt")
+            #     generated_speech_tokens = a["semantic_token_ids_arr"]
+            #     print(generated_speech_tokens)
+            assert prompt_audio_sample_rate == 16000
+
+            prompt_text = prompt_text_list[0]
+            prompt_speech_tokens = prompt_speech_tokens_list[0]
+
+
+            # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens)
+
+            semantic_token_ids_arr, token_offset = [], 0
+            flow_prompt_speech_token_len = len(prompt_speech_tokens)
+    
+            buffer = generated_speech_tokens
+            output_wavs = []
+            while True:
+
+                if len(buffer) >= CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len:
+                    wavs = token2wav_model.forward_streaming(buffer[:CHUNK_SIZE + token2wav_model.flow.pre_lookahead_len], False, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+                    buffer = buffer[CHUNK_SIZE - OVERLAP_SIZE:]
+
+                    output_wavs.append(wavs)
+
+                else:
+                    wavs = token2wav_model.forward_streaming(buffer, True, request_id=id, speaker_id=f"{id}", prompt_audio=prompt_audio, prompt_audio_sample_rate=prompt_audio_sample_rate)
+                    output_wavs.append(wavs)
+                    break
+
+            for i, wav in enumerate(output_wavs):
+                output_wavs[i] = wav.cpu().numpy().squeeze()
+
+
+            audios = output_wavs            
+            reconstructed_audio = np.concatenate(audios)
+            # Save reconstructed audio
+            sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
+
+
+            print(f"Saved {id}")
+        end_time = time.time()
+
+        if _ == 0:
+            token2wav_model.speaker_cache = {}
+        print(f"Warmup time: {end_time - start_time} seconds")
+