Browse Source

clean code

root 1 month ago
parent
commit
f186ec3338

+ 193 - 196
runtime/triton_trtllm/model_repo/cosyvoice2_dit/1/model.py

@@ -28,9 +28,10 @@ import json
 import math
 import os
 import re
-import threading
 import time
 from typing import Dict, List, Tuple, Optional, Union
+import asyncio
+import httpx
 
 import numpy as np
 import torch
@@ -42,11 +43,30 @@ import torchaudio
 
 
 from matcha.utils.audio import mel_spectrogram
+from datetime import datetime
 
 ORIGINAL_VOCAB_SIZE = 151663
 torch.set_num_threads(1)
 
 
+def parse_speech_token_string(response_text: str) -> List[int]:
+    """
+    Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
+    """
+    speech_tokens = response_text.strip().split('><')
+    if len(speech_tokens) > 1:
+        # Add back the missing '<' and '>' for proper parsing
+        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
+        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
+
+    speech_ids = []
+    for token_str in speech_tokens:
+        match = re.match(r'<\|s_(\d+)\|>', token_str)
+        if match:
+            speech_ids.append(int(match.group(1)))
+    return speech_ids
+
+
 class TritonPythonModel:
     """Triton Python model for Spark TTS.
 
@@ -67,6 +87,7 @@ class TritonPythonModel:
         model_params = {k: v["string_value"] for k, v in parameters.items()}
         self.logger.log_info(f"model_params:{model_params}")
         self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
+        # self.dynamic_chunk_strategy = "equal"
         self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
 
         # Initialize tokenizer
@@ -87,92 +108,86 @@ class TritonPythonModel:
             raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
         spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
         self.default_spk_info = spk_info["001"]
-
-    def forward_llm(self, input_ids):
+        self.http_client = httpx.AsyncClient()
+
+    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
+        """Converts a tensor or list of speech token IDs to a string representation."""
+        if isinstance(speech_tokens, torch.Tensor):
+            # Ensure tensor is on CPU and flattened
+            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
+
+        speech_id_str = ""
+        for token_id in speech_tokens:
+            # Convert token ID back to the speech number N
+            token_num = token_id - ORIGINAL_VOCAB_SIZE
+            speech_id_str += f"<|s_{token_num}|>"
+        return speech_id_str
+
+    async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
         """
-        Prepares the response from the language model based on the provided
-        inputs. Creates a `pb_utils.InferenceRequest` object with passed
-        `llm_request_inputs` to send to a decoupled TensorRTLLM model.
-        For each response from the language model:
-            - Checks for errors and raise an exception if any are found.
-            - Extracts the "output_ids" tensor from the response.
-            - Determines the finish reason based on the presence of the
-              end-of-sequence token or reaching the maximum length.
-            - Appends the generated token IDs to `output_ids`.
-            - If the finish reason is determined, decodes the output IDs to text
-              and prepares the final response.
-
-        The final response includes the generated text, finish reason,
-        completion tokens, prompt tokens, and total tokens.
-
-        Parameters
-        ----------
-        - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
-
-        Returns
-        -------
-        - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
+        Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
         """
-        # convert input_ids to numpy, with shape [1, sequence_length]
-        input_ids = input_ids.cpu().numpy()
-        max_tokens = 750
-        input_dict = {
-            "request_output_len": np.array([[max_tokens]], dtype=np.int32),
-            "end_id": np.array([[self.eos_token_id]], dtype=np.int32),
-            "pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
-            "streaming": np.array([[self.decoupled]], dtype=np.bool_),
-            "runtime_top_p": np.array([[0.95]], dtype=np.float32),
-            "runtime_top_k": np.array([[50]], dtype=np.int32),
-            "temperature": np.array([[0.8]], dtype=np.float32),
-            "repetition_penalty": np.array([[1.1]], dtype=np.float32),
-            "random_seed": np.array([[42]], dtype=np.uint64),
-            "input_ids": input_ids,
-            "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
-        }
+        full_text = f"{reference_text}{target_text}"
+        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
 
-        # Convert inputs to Triton tensors
-        input_tensor_list = [
-            pb_utils.Tensor(k, v) for k, v in input_dict.items()
+        chat = [
+            {"role": "user", "content": full_text},
+            {"role": "assistant", "content": prompt_speech_tokens_str}
         ]
 
-        # Create and execute inference request
-        llm_request = pb_utils.InferenceRequest(
-            model_name="tensorrt_llm",
-            requested_output_names=["output_ids", "sequence_length"],
-            inputs=input_tensor_list,
-        )
-
-        llm_responses = llm_request.exec(decoupled=self.decoupled)
-        if self.decoupled:
-            for llm_response in llm_responses:
-                if llm_response.has_error():
-                    raise pb_utils.TritonModelException(llm_response.error().message())
-
-                # Extract and process output
-                output_ids = pb_utils.get_output_tensor_by_name(
-                    llm_response, "output_ids").as_numpy()
-                seq_lens = pb_utils.get_output_tensor_by_name(
-                    llm_response, "sequence_length").as_numpy()
-
-                # Get actual output IDs up to the sequence length
-                actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
-
-                yield actual_output_ids
-        else:
-            llm_response = llm_responses
-            if llm_response.has_error():
-                raise pb_utils.TritonModelException(llm_response.error().message())
-
-            # Extract and process output
-            output_ids = pb_utils.get_output_tensor_by_name(
-                llm_response, "output_ids").as_numpy()
-            seq_lens = pb_utils.get_output_tensor_by_name(
-                llm_response, "sequence_length").as_numpy()
+        payload = {
+            "model": "trt_engines_bfloat16",
+            "messages": chat,
+            "max_tokens": 750,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "top_k": 50,
+            "repetition_penalty": 1.1,
+            "stop": ["<|eos1|>", "<|eos|>"],
+            "stream": True,
+        }
 
-            # Get actual output IDs up to the sequence length
-            actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
+        api_base = "http://localhost:8000/v1/chat/completions"
+
+        buffer = ""
+        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
+            print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
+            print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if line.startswith("data: "):
+                    line_data = line[len("data: "):].strip()
+                    if line_data == "[DONE]":
+                        break
+                    try:
+                        json_data = json.loads(line_data)
+                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
+                        if content:
+                            buffer += content
+                            print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
+                            while True:
+                                match = re.search(r"<\|s_(\d+)\|>", buffer)
+                                if not match:
+                                    break
+
+                                token_num = int(match.group(1))
+                                final_id = token_num + ORIGINAL_VOCAB_SIZE
+                                yield final_id
+                                buffer = buffer[match.end():]
+                    except json.JSONDecodeError:
+                        self.logger.log_info(f"Skipping non-JSON line: {line_data}")
+                        continue
+
+        # Process any remaining complete tokens in the buffer after the stream ends
+        while True:
+            match = re.search(r"<\|s_(\d+)\|>", buffer)
+            if not match:
+                break
+            token_num = int(match.group(1))
+            final_id = token_num + ORIGINAL_VOCAB_SIZE
+            yield final_id
+            buffer = buffer[match.end():]
 
-            yield actual_output_ids
 
     def forward_audio_tokenizer(self, wav, wav_len):
         """Forward pass through the audio tokenizer component.
@@ -225,7 +240,7 @@ class TritonPythonModel:
 
         return prompt_spk_embedding
 
-    def forward_token2wav(
+    async def forward_token2wav(
             self,
             index: int,
             target_speech_tokens: torch.Tensor,
@@ -247,17 +262,19 @@ class TritonPythonModel:
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
         finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
         inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
-
+        
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav_dit',
-            requested_output_names=['waveform'],
+            requested_output_names=[
+                "waveform",
+            ],
             inputs=inputs_tensor,
             request_id=request_id,
             parameters={"priority": index+1},
         )
 
-        inference_response = inference_request.exec()
+        inference_response = await inference_request.async_exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
 
@@ -267,14 +284,6 @@ class TritonPythonModel:
 
         return waveform
 
-    def parse_input(self, text, prompt_text, prompt_speech_tokens):
-        total_text = f"{prompt_text}{text}"
-        prompt = self.prompt_template.format(input_text=total_text)
-        input_ids = self.tokenizer.encode(prompt)
-        input_ids = torch.tensor([input_ids], dtype=torch.int32)
-        input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
-        return input_ids
-
     def _extract_speech_feat(self, speech):
         speech_feat = mel_spectrogram(
             speech,
@@ -292,106 +301,75 @@ class TritonPythonModel:
         speech_feat = speech_feat.unsqueeze(dim=0)
         return speech_feat
 
-    def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
-        for generated_ids in generated_ids_iter:
-            generated_ids = generated_ids.tolist()
-            if len(generated_ids) == 0:
-                break
-            semantic_token_ids_arr.extend(generated_ids)
-        llm_is_done_flag[0] = True
+    async def _process_request(self, request):
+        request_id = request.request_id()
+        # Extract input tensors
+        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
+
+        # Process reference audio through audio tokenizer
+        if wav is not None:
+            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
+            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
+            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
+
+            wav_tensor = wav.as_numpy()
+            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
+            print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
+            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
+            speech_feat = self._extract_speech_feat(prompt_speech_resample)
+            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
+            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
+            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
+
+            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
+            reference_text = reference_text[0][0].decode('utf-8')
+            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
+
+            # reference_text = self.default_spk_info["prompt_text"]
+            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+            # prompt_speech_feat = None
+            # prompt_spk_embedding = None
 
-    def execute(self, requests):
-        """Execute inference on the batched requests.
+        else:
+            # using pre-cached reference text
+            assert False, "using pre-cached reference text is not supported"
+            reference_text = self.default_spk_info["prompt_text"]
+            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
+            prompt_speech_feat = None
+            prompt_spk_embedding = None
 
-        Args:
-            requests: List of inference requests
+        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
+        target_text = target_text[0][0].decode('utf-8')
+        print(f"target_text: {target_text}, time: {datetime.now()}")
 
-        Returns:
-            List of inference responses containing generated audio
-        """
-        responses = []
-
-        for request in requests:
-            request_id = request.request_id()
-            # Extract input tensors
-            wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
-
-            # Process reference audio through audio tokenizer
-            if wav is not None:
-                wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
-                prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-                prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-
-                wav_tensor = wav.as_numpy()
-                wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-                prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-                speech_feat = self._extract_speech_feat(prompt_speech_resample)
-                token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-                prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-                prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-
-                reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
-                reference_text = reference_text[0][0].decode('utf-8')
-                # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
-
-                # reference_text = self.default_spk_info["prompt_text"]
-                # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
-                # prompt_speech_feat = None
-                # prompt_spk_embedding = None
-
-            else:
-                assert False, "wav is None"
-                # using pre-cached reference text
-                reference_text = self.default_spk_info["prompt_text"]
-                prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
-                prompt_speech_feat = None
-                prompt_spk_embedding = None
-
-            target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
-            target_text = target_text[0][0].decode('utf-8')
-
-            # Prepare prompt for LLM
-            input_ids = self.parse_input(
-                text=target_text,
-                prompt_text=reference_text,
+        if self.decoupled:
+            response_sender = request.get_response_sender()
+
+            semantic_token_ids_arr = []
+            token_offset, chunk_index = 0, 0
+            start_time = time.time()
+            this_token_hop_len = self.token_hop_len
+            print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
+            async for generated_ids in self.forward_llm_async(
+                target_text=target_text,
+                reference_text=reference_text,
                 prompt_speech_tokens=prompt_speech_tokens,
-            )
-
-            # Generate semantic tokens with LLM
-            generated_ids_iter = self.forward_llm(input_ids)
-
-            if self.decoupled:
-                response_sender = request.get_response_sender()
-
-                semantic_token_ids_arr = []
-                llm_is_done_flag = [False]
-
-                llm_thread = threading.Thread(
-                    target=self._llm_gen_thread,
-                    args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
-                )
-
-                llm_thread.start()
-
-                token_offset, chunk_index = 0, 0
-                start_time = time.time()
-                this_token_hop_len = self.token_hop_len
-
+            ):
+                if not generated_ids:
+                    break
+                semantic_token_ids_arr.append(generated_ids)
+                print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
                 while True:
                     pending_num = len(semantic_token_ids_arr) - token_offset
-
-                    if llm_is_done_flag[0]:
-                        break
-
                     if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
                         this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
                         this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
-
-                        sub_tts_speech = self.forward_token2wav(
+                        print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
+                        sub_tts_speech = await self.forward_token2wav(
                             chunk_index,
                             this_tts_speech_token, request_id, wav, wav_len, False
                         )
-
+                        print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
                         audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
                         inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                         response_sender.send(inference_response)
@@ -401,6 +379,8 @@ class TritonPythonModel:
 
                         if self.dynamic_chunk_strategy == "exponential":
                             this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
+                        elif self.dynamic_chunk_strategy == "equal":
+                            this_token_hop_len = self.token_hop_len
                         elif self.dynamic_chunk_strategy == "time_based":
                             # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
                             cost_time = time.time() - start_time
@@ -420,19 +400,36 @@ class TritonPythonModel:
                                     this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
                         chunk_index += 1
                     else:
-                        time.sleep(0.02)
-
-                this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                sub_tts_speech = self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
-                audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
-                inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
-                response_sender.send(inference_response)
-
-                llm_thread.join()
-                response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-                self.logger.log_info("send tritonserver_response_complete_final to end")
-            else:
-                raise NotImplementedError("Decoupled mode is not supported")
-
-        if not self.decoupled:
-            return responses
+                        break
+            
+            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
+            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
+            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
+            response_sender.send(inference_response)
+
+            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+            self.logger.log_info("send tritonserver_response_complete_final to end")
+        else:
+            raise NotImplementedError("Decoupled mode is not supported")
+
+    async def execute(self, requests):
+        """Execute inference on the batched requests.
+
+        Args:
+            requests: List of inference requests
+
+        Returns:
+            List of inference responses containing generated audio
+        """
+        tasks = [
+            asyncio.create_task(self._process_request(request))
+            for request in requests
+        ]
+        await asyncio.gather(*tasks)
+        return None
+
+    def finalize(self):
+        self.logger.log_info("Finalizing CosyVoice DIT model")
+        if hasattr(self, "http_client"):
+            asyncio.run(self.http_client.aclose())

+ 0 - 435
runtime/triton_trtllm/model_repo/cosyvoice2_dit/3/model.py

@@ -1,435 +0,0 @@
-# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-#  * Redistributions of source code must retain the above copyright
-#    notice, this list of conditions and the following disclaimer.
-#  * Redistributions in binary form must reproduce the above copyright
-#    notice, this list of conditions and the following disclaimer in the
-#    documentation and/or other materials provided with the distribution.
-#  * Neither the name of NVIDIA CORPORATION nor the names of its
-#    contributors may be used to endorse or promote products derived
-#    from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import json
-import math
-import os
-import re
-import time
-from typing import Dict, List, Tuple, Optional, Union
-import asyncio
-import httpx
-
-import numpy as np
-import torch
-from torch.utils.dlpack import from_dlpack, to_dlpack
-import triton_python_backend_utils as pb_utils
-from transformers import AutoTokenizer
-
-import torchaudio
-
-
-from matcha.utils.audio import mel_spectrogram
-from datetime import datetime
-
-ORIGINAL_VOCAB_SIZE = 151663
-torch.set_num_threads(1)
-
-
-def parse_speech_token_string(response_text: str) -> List[int]:
-    """
-    Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
-    """
-    speech_tokens = response_text.strip().split('><')
-    if len(speech_tokens) > 1:
-        # Add back the missing '<' and '>' for proper parsing
-        speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
-        speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
-
-    speech_ids = []
-    for token_str in speech_tokens:
-        match = re.match(r'<\|s_(\d+)\|>', token_str)
-        if match:
-            speech_ids.append(int(match.group(1)))
-    return speech_ids
-
-
-class TritonPythonModel:
-    """Triton Python model for Spark TTS.
-
-    This model orchestrates the end-to-end TTS pipeline by coordinating
-    between audio tokenizer, LLM, and vocoder components.
-    """
-
-    def initialize(self, args):
-        """Initialize the model.
-
-        Args:
-            args: Dictionary containing model configuration
-        """
-        self.logger = pb_utils.Logger
-        # Parse model parameters
-        self.model_config = json.loads(args['model_config'])
-        parameters = self.model_config['parameters']
-        model_params = {k: v["string_value"] for k, v in parameters.items()}
-        self.logger.log_info(f"model_params:{model_params}")
-        self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential")  # "exponential" or "time_based"
-        # self.dynamic_chunk_strategy = "equal"
-        self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
-
-        # Initialize tokenizer
-        llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
-        self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
-        self.prompt_template = "<|sos|>{input_text}<|task_id|>"
-        self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
-
-        self.device = torch.device("cuda")
-        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
-
-        self.token_frame_rate = 25
-        self.flow_pre_lookahead_len = 3
-        self.token_hop_len = 15
-
-        spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
-        if not os.path.exists(spk_info_path):
-            raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
-        spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
-        self.default_spk_info = spk_info["001"]
-        self.http_client = httpx.AsyncClient()
-
-    def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
-        """Converts a tensor or list of speech token IDs to a string representation."""
-        if isinstance(speech_tokens, torch.Tensor):
-            # Ensure tensor is on CPU and flattened
-            speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
-
-        speech_id_str = ""
-        for token_id in speech_tokens:
-            # Convert token ID back to the speech number N
-            token_num = token_id - ORIGINAL_VOCAB_SIZE
-            speech_id_str += f"<|s_{token_num}|>"
-        return speech_id_str
-
-    async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
-        """
-        Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
-        """
-        full_text = f"{reference_text}{target_text}"
-        prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
-
-        chat = [
-            {"role": "user", "content": full_text},
-            {"role": "assistant", "content": prompt_speech_tokens_str}
-        ]
-
-        payload = {
-            "model": "trt_engines_bfloat16",
-            "messages": chat,
-            "max_tokens": 750,
-            "temperature": 0.8,
-            "top_p": 0.95,
-            "top_k": 50,
-            "repetition_penalty": 1.1,
-            "stop": ["<|eos1|>", "<|eos|>"],
-            "stream": True,
-        }
-
-        api_base = "http://localhost:8000/v1/chat/completions"
-
-        buffer = ""
-        async with self.http_client.stream("POST", api_base, json=payload, timeout=None) as response:
-            print(f"start httpx.AsyncClient, target_text: {target_text[:5]}, time: {datetime.now()}")
-            print(f"start response.aiter_lines, target_text: {target_text[:5]}, time: {datetime.now()}")
-            response.raise_for_status()
-            async for line in response.aiter_lines():
-                if line.startswith("data: "):
-                    line_data = line[len("data: "):].strip()
-                    if line_data == "[DONE]":
-                        break
-                    try:
-                        json_data = json.loads(line_data)
-                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
-                        if content:
-                            buffer += content
-                            print(f"buffer: {buffer}, target_text: {target_text[:5]}, time: {datetime.now()}")
-                            while True:
-                                match = re.search(r"<\|s_(\d+)\|>", buffer)
-                                if not match:
-                                    break
-
-                                token_num = int(match.group(1))
-                                final_id = token_num + ORIGINAL_VOCAB_SIZE
-                                yield final_id
-                                buffer = buffer[match.end():]
-                    except json.JSONDecodeError:
-                        self.logger.log_info(f"Skipping non-JSON line: {line_data}")
-                        continue
-
-        # Process any remaining complete tokens in the buffer after the stream ends
-        while True:
-            match = re.search(r"<\|s_(\d+)\|>", buffer)
-            if not match:
-                break
-            token_num = int(match.group(1))
-            final_id = token_num + ORIGINAL_VOCAB_SIZE
-            yield final_id
-            buffer = buffer[match.end():]
-
-
-    def forward_audio_tokenizer(self, wav, wav_len):
-        """Forward pass through the audio tokenizer component.
-
-        Args:
-            wav: Input waveform tensor
-            wav_len: Waveform length tensor
-
-        Returns:
-            Tuple of global and semantic tokens
-        """
-        inference_request = pb_utils.InferenceRequest(
-            model_name='audio_tokenizer',
-            requested_output_names=['prompt_speech_tokens'],
-            inputs=[wav, wav_len]
-        )
-
-        inference_response = inference_request.exec()
-        if inference_response.has_error():
-            raise pb_utils.TritonModelException(inference_response.error().message())
-
-        # Extract and convert output tensors
-        prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
-        prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
-
-        return prompt_speech_tokens
-
-    def forward_speaker_embedding(self, wav):
-        """Forward pass through the speaker embedding component.
-
-        Args:
-            wav: Input waveform tensor
-
-        Returns:
-            Prompt speaker embedding tensor
-        """
-        inference_request = pb_utils.InferenceRequest(
-            model_name='speaker_embedding',
-            requested_output_names=['prompt_spk_embedding'],
-            inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
-        )
-
-        inference_response = inference_request.exec()
-        if inference_response.has_error():
-            raise pb_utils.TritonModelException(inference_response.error().message())
-
-        # Extract and convert output tensors
-        prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
-        prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
-
-        return prompt_spk_embedding
-
-    async def forward_token2wav(
-            self,
-            index: int,
-            target_speech_tokens: torch.Tensor,
-            request_id: str,
-            reference_wav: object,
-            reference_wav_len: object,
-            finalize: bool = None) -> torch.Tensor:
-        """Forward pass through the vocoder component.
-
-        Args:
-            prompt_speech_tokens: Prompt speech tokens tensor
-            prompt_speech_feat: Prompt speech feat tensor
-            prompt_spk_embedding: Prompt spk embedding tensor
-            target_speech_tokens: Target speech tokens tensor
-
-        Returns:
-            Generated waveform tensor
-        """
-        target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
-        finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
-        inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
-        
-        # Create and execute inference request
-        inference_request = pb_utils.InferenceRequest(
-            model_name='token2wav_dit',
-            requested_output_names=[
-                "waveform",
-            ],
-            inputs=inputs_tensor,
-            request_id=request_id,
-            parameters={"priority": index+1},
-        )
-
-        inference_response = await inference_request.async_exec()
-        if inference_response.has_error():
-            raise pb_utils.TritonModelException(inference_response.error().message())
-
-        # Extract and convert output waveform
-        waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
-        waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
-
-        return waveform
-
-    def _extract_speech_feat(self, speech):
-        speech_feat = mel_spectrogram(
-            speech,
-            n_fft=1920,
-            num_mels=80,
-            sampling_rate=24000,
-            hop_size=480,
-            win_size=1920,
-            fmin=0,
-            fmax=8000).squeeze(
-            dim=0).transpose(
-            0,
-            1).to(
-                self.device)
-        speech_feat = speech_feat.unsqueeze(dim=0)
-        return speech_feat
-
-    async def _process_request(self, request):
-        request_id = request.request_id()
-        # Extract input tensors
-        wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
-
-        # Process reference audio through audio tokenizer
-        if wav is not None:
-            wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
-            prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
-            prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
-
-            wav_tensor = wav.as_numpy()
-            wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
-            print(f"wav_tensor: {wav_tensor.shape}, time: {datetime.now()}")
-            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
-            speech_feat = self._extract_speech_feat(prompt_speech_resample)
-            token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
-            prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
-            prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-
-            reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
-            reference_text = reference_text[0][0].decode('utf-8')
-            # prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
-
-            # reference_text = self.default_spk_info["prompt_text"]
-            # prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
-            # prompt_speech_feat = None
-            # prompt_spk_embedding = None
-
-        else:
-            # using pre-cached reference text
-            assert False, "using pre-cached reference text is not supported"
-            reference_text = self.default_spk_info["prompt_text"]
-            prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
-            prompt_speech_feat = None
-            prompt_spk_embedding = None
-
-        target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
-        target_text = target_text[0][0].decode('utf-8')
-        print(f"target_text: {target_text}, time: {datetime.now()}")
-
-        if self.decoupled:
-            response_sender = request.get_response_sender()
-
-            semantic_token_ids_arr = []
-            token_offset, chunk_index = 0, 0
-            start_time = time.time()
-            this_token_hop_len = self.token_hop_len
-            print(f"start forward_llm_async, target_text: {target_text[:5]}, time: {datetime.now()}")
-            async for generated_ids in self.forward_llm_async(
-                target_text=target_text,
-                reference_text=reference_text,
-                prompt_speech_tokens=prompt_speech_tokens,
-            ):
-                if not generated_ids:
-                    break
-                semantic_token_ids_arr.append(generated_ids)
-                print(f"generated_ids: {generated_ids}, target_text: {target_text[:5]}, time: {datetime.now()}")
-                while True:
-                    pending_num = len(semantic_token_ids_arr) - token_offset
-                    if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
-                        this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
-                        this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
-                        print(f"chunk_index: {chunk_index}, target_text: {target_text[:5]}, time: {datetime.now()}")
-                        sub_tts_speech = await self.forward_token2wav(
-                            chunk_index,
-                            this_tts_speech_token, request_id, wav, wav_len, False
-                        )
-                        print(f"finish token2wav, target_text: {target_text[:5]}, time: {datetime.now()}")
-                        audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
-                        inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
-                        response_sender.send(inference_response)
-
-                        token_offset += this_token_hop_len
-                        self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
-
-                        if self.dynamic_chunk_strategy == "exponential":
-                            this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
-                        elif self.dynamic_chunk_strategy == "equal":
-                            this_token_hop_len = self.token_hop_len
-                        elif self.dynamic_chunk_strategy == "time_based":
-                            # see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
-                            cost_time = time.time() - start_time
-                            duration = token_offset / self.token_frame_rate
-                            if chunk_index > 0 and cost_time > 0:
-                                avg_chunk_processing_time = cost_time / (chunk_index + 1)
-                                if avg_chunk_processing_time > 0:
-                                    multiples = (duration - cost_time) / avg_chunk_processing_time
-                                    self.logger.log_info(f"multiples: {multiples}")
-                                    next_pending_num = len(semantic_token_ids_arr) - token_offset
-                                    if multiples > 4:
-                                        this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
-                                    elif multiples > 2:
-                                        this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
-                                    else:
-                                        this_token_hop_len = self.token_hop_len
-                                    this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
-                        chunk_index += 1
-                    else:
-                        break
-            
-            this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
-            sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
-            audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
-            inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
-            response_sender.send(inference_response)
-
-            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-            self.logger.log_info("send tritonserver_response_complete_final to end")
-        else:
-            raise NotImplementedError("Decoupled mode is not supported")
-
-    async def execute(self, requests):
-        """Execute inference on the batched requests.
-
-        Args:
-            requests: List of inference requests
-
-        Returns:
-            List of inference responses containing generated audio
-        """
-        tasks = [
-            asyncio.create_task(self._process_request(request))
-            for request in requests
-        ]
-        await asyncio.gather(*tasks)
-        return None
-
-    def finalize(self):
-        self.logger.log_info("Finalizing CosyVoice DIT model")
-        if hasattr(self, "http_client"):
-            asyncio.run(self.http_client.aclose())

+ 6 - 4
runtime/triton_trtllm/offline_inference.py

@@ -47,8 +47,6 @@ import requests
 import asyncio
 import httpx
 
-from token2wav import CosyVoice2_Token2Wav
-
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 try:
     torch.multiprocessing.set_start_method("spawn")
@@ -367,7 +365,12 @@ def main(args):
         runner = None
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
-
+    
+    if 'Step-Audio-2-mini' in args.token2wav_path:
+        from token2wav_dit import CosyVoice2_Token2Wav
+    else:
+        assert 'CosyVoice2-0.5B' in args.token2wav_path
+        from token2wav import CosyVoice2_Token2Wav
     token2wav_model = CosyVoice2_Token2Wav(
         model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
     )
@@ -589,7 +592,6 @@ def main(args):
                         t2w_prompt_audios_list,
                         t2w_prompt_audios_sample_rate,
                     )
-                    torch.cuda.synchronize()
                     token2wav_end_time = time.time()
                     total_token2wav_time += (token2wav_end_time - token2wav_start_time)
 

+ 69 - 161
runtime/triton_trtllm/run_stepaudio2_dit_token2wav.sh

@@ -1,28 +1,33 @@
 #!/bin/bash
 # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang)
 export CUDA_VISIBLE_DEVICES=0
-cosyvoice_path=/workspace/CosyVoice
+# cosyvoice_path=/workspace/CosyVoice
 cosyvoice_path=/workspace_yuekai/tts/CosyVoice
 stepaudio2_path=/workspace_yuekai/tts/Step-Audio2
+
 export PYTHONPATH=${stepaudio2_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH
 export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH
+
 stage=$1
 stop_stage=$2
-N_GPUS=2 # set the number of GPUs to use
-
 
 huggingface_model_local_dir=./cosyvoice2_llm
 model_scope_model_local_dir=./CosyVoice2-0.5B
+step_audio_model_dir=./Step-Audio-2-mini
+
 trt_dtype=bfloat16
 trt_weights_dir=./trt_weights_${trt_dtype}
 trt_engines_dir=./trt_engines_${trt_dtype}
 
 model_repo=./model_repo_cosyvoice2_dit
-
-use_spk2info_cache=False
+bls_instance_num=4
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+
+    echo "Cloning Step-Audio2-mini"
+    git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt $stepaudio2_path
+
     echo "Cloning CosyVoice"
     git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path
     cd $cosyvoice_path
@@ -35,8 +40,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     # see https://github.com/nvidia-china-sae/mair-hub/blob/main/rl-tutorial/cosyvoice_llm/pretrained_to_huggingface.py
     huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm
     modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir
-    # download spk2info.pt to directly use cached speech tokens, speech feats, and embeddings
-    wget https://raw.githubusercontent.com/qi-hua/async_cosyvoice/main/CosyVoice2-0.5B/spk2info.pt -O $model_scope_model_local_dir/spk2info.pt
+
+    echo "Step-Audio2-mini"
+    huggingface-cli download --local-dir $step_audio_model_dir stepfun-ai/Step-Audio-2-mini
+    cd $stepaudio2_path/token2wav
+    wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O flow.decoder.estimator.fp32.dynamic_batch.onnx
+    wget https://huggingface.co/yuekai/cosyvoice2_dit_flow_matching_onnx/resolve/main/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx -O flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx
+    cd -
 fi
 
 
@@ -60,40 +70,6 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
                     --engine_dir=$trt_engines_dir  || exit 1
 fi
 
-
-# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-#     echo "Creating model repository"
-#     rm -rf $model_repo
-#     mkdir -p $model_repo
-#     cosyvoice2_dir="cosyvoice2_dit"
-#     token2wav_dir="token2wav_dit"
-
-#     cp -r ./model_repo/${cosyvoice2_dir} $model_repo
-#     cp -r ./model_repo/tensorrt_llm $model_repo
-#     cp -r ./model_repo/${token2wav_dir} $model_repo
-#     #if [ $use_spk2info_cache == "False" ]; then
-#         cp -r ./model_repo/audio_tokenizer $model_repo
-#         cp -r ./model_repo/speaker_embedding $model_repo
-#     #fi
-
-#     ENGINE_PATH=$trt_engines_dir
-#     MAX_QUEUE_DELAY_MICROSECONDS=0
-#     MODEL_DIR=$model_scope_model_local_dir
-#     LLM_TOKENIZER_DIR=$huggingface_model_local_dir
-#     BLS_INSTANCE_NUM=1
-#     TRITON_MAX_BATCH_SIZE=16
-#     DECOUPLED_MODE=True # True for streaming, False for offline
-#     STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
-
-#     python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-#     python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-#     python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
-#     #if [ $use_spk2info_cache == "False" ]; then
-#         python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-#         python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-#     #fi
-# fi
-
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     echo "Creating model repository async mode"
     rm -rf $model_repo
@@ -102,122 +78,75 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     token2wav_dir="token2wav_dit"
 
     cp -r ./model_repo/${cosyvoice2_dir} $model_repo
-    cp -r ./model_repo/tensorrt_llm $model_repo
     cp -r ./model_repo/${token2wav_dir} $model_repo
-    #if [ $use_spk2info_cache == "False" ]; then
-        cp -r ./model_repo/audio_tokenizer $model_repo
-        cp -r ./model_repo/speaker_embedding $model_repo
-    #fi
+    cp -r ./model_repo/audio_tokenizer $model_repo
+    cp -r ./model_repo/speaker_embedding $model_repo
+
 
     ENGINE_PATH=$trt_engines_dir
     MAX_QUEUE_DELAY_MICROSECONDS=0
     MODEL_DIR=$model_scope_model_local_dir
     LLM_TOKENIZER_DIR=$huggingface_model_local_dir
-    BLS_INSTANCE_NUM=4
+    BLS_INSTANCE_NUM=$bls_instance_num
     TRITON_MAX_BATCH_SIZE=1
-    DECOUPLED_MODE=True # True for streaming, False for offline
-    STEP_AUDIO_MODEL_DIR=/workspace_yuekai/tts/CosyVoice/runtime/triton_trtllm/Step-Audio-2-mini/token2wav
+    DECOUPLED_MODE=True
+    STEP_AUDIO_MODEL_DIR=$step_audio_model_dir/token2wav
 
     python3 scripts/fill_template.py -i ${model_repo}/${token2wav_dir}/config.pbtxt model_dir:${STEP_AUDIO_MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
     python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-    python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
-    #if [ $use_spk2info_cache == "False" ]; then
-        python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-        python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
-    #fi
-    rm -rf $model_repo/tensorrt_llm
-    # mv $model_repo/cosyvoice2_dit/1 $model_repo/cosyvoice2_dit/4
+    python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+    python3 scripts/fill_template.py -i ${model_repo}/speaker_embedding/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
+
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-   echo "Starting Triton server on $N_GPUS GPUs"
-   for i in $(seq 0 $(($N_GPUS - 1))); do
-       echo "Starting server on GPU $i"
-       http_port=$((19000 + $i))
-       grpc_port=$((18000 + $i))
-       metrics_port=$((17000 + $i))
-       CUDA_VISIBLE_DEVICES=$i tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
-   done
-
-   echo "Servers are running in the background. Press Ctrl+C to stop them and the script."
+   echo "Starting Token2wav Triton server and Cosyvoice2 llm using trtllm-serve"
+   tritonserver --model-repository $model_repo --http-port 18000 &
+   mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16  --kv_cache_free_gpu_memory_fraction 0.4 &
    wait
+    # Test using curl
+    # curl http://localhost:8000/v1/chat/completions \
+    #     -H "Content-Type: application/json" \
+    #     -d '{
+    #         "model": "trt_engines_bfloat16",
+    #         "messages":[{"role": "user", "content": "Where is New York?"},
+    #                     {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
+    #         "max_tokens": 512,
+    #         "temperature": 0.8,
+    #         "top_p": 0.95,
+    #         "top_k": 50,
+    #         "stop": ["<|eos1|>"],
+    #         "repetition_penalty": 1.2,
+    #         "stream": false
+    #     }'
 fi
 
-if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
-   echo "Starting Triton server on $N_GPUS GPUs"
-   N_GPUS=1
-   for i in $(seq 0 $(($N_GPUS - 1))); do
-       echo "Starting server on GPU $i"
-       http_port=$((19000 + $i))
-       grpc_port=$((18000 + $i))
-       metrics_port=$((17000 + $i))
-       CUDA_VISIBLE_DEVICES=0 tritonserver --model-repository $model_repo --http-port $http_port --grpc-port $grpc_port --metrics-port $metrics_port &
-   done
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+    echo "Running benchmark client"
+    num_task=4
+    mode=streaming
+    BLS_INSTANCE_NUM=$bls_instance_num
 
-   echo "Servers are running in the background. Press Ctrl+C to stop them and the script."
-   wait
-fi
+    python3 client_grpc.py \
+        --server-addr localhost \
+        --server-port 8001 \
+        --model-name cosyvoice2_dit \
+        --num-tasks $num_task \
+        --mode $mode \
+        --huggingface-dataset yuekai/seed_tts_cosy2 \
+        --log-dir ./log_single_gpu_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}
 
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-    echo "Single request test http, only work for offline TTS mode"
-    python3 client_http.py \
-        --reference-audio ./assets/prompt_audio.wav \
-        --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
-        --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
-        --model-name cosyvoice2
 fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-    echo "Running benchmark client grpc on $N_GPUS GPUs"
-    num_task=1
-
-    mode=streaming
-    BLS_INSTANCE_NUM=4
-
-    for i in $(seq 0 $(($N_GPUS - 1))); do
-        grpc_port=$((18000 + $i))
-        echo "Running client for server on localhost:$grpc_port"
-        python3 client_grpc.py \
-            --server-addr localhost \
-            --server-port $grpc_port \
-            --model-name cosyvoice2_dit \
-            --num-tasks $num_task \
-            --mode $mode \
-            --huggingface-dataset yuekai/seed_tts_cosy2 \
-            --log-dir ./log_debug_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM}_gpu${i} &
-    done
-    wait
-fi
-if [ $stage -le 50 ] && [ $stop_stage -ge 50 ]; then
-    echo "Running benchmark client grpc on $N_GPUS GPUs"
-    num_task=4
-    N_GPUS=1
-    mode=streaming
-    BLS_INSTANCE_NUM=4
+  echo "stage 5: Offline TTS (Cosyvoice2 LLM + Step-Audio2-mini DiT Token2Wav) inference using a single python script"
 
-    for i in $(seq 0 $(($N_GPUS - 1))); do
-        grpc_port=$((18000 + $i))
-        echo "Running client for server on localhost:$grpc_port"
-        python3 client_grpc.py \
-            --server-addr localhost \
-            --server-port $grpc_port \
-            --model-name cosyvoice2_dit \
-            --num-tasks $num_task \
-            --mode $mode \
-            --huggingface-dataset yuekai/seed_tts_cosy2 \
-            --log-dir ./log_single_card_concurrent_tasks_${num_task}_${mode}_bls_${BLS_INSTANCE_NUM} &
-    done
-    wait
-fi
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  echo "stage 6: Offline inference benchmark"
-  n_gpus=1
   datasets=(wenetspeech4tts) # wenetspeech4tts, test_zh, zero_shot_zh
-  backend=trtllm-serve # hf, trtllm, vllm
+  backend=trtllm # hf, trtllm, vllm, trtllm-serve
 
-  batch_sizes=(16 8 4 2 1)
-  batch_sizes=(16 8 4 2)
+  batch_sizes=(16)
   token2wav_batch_size=1
+
   for batch_size in ${batch_sizes[@]}; do
     for dataset in ${datasets[@]}; do
     output_dir=./${dataset}_${backend}_llm_batch_size_${batch_size}_token2wav_batch_size_${token2wav_batch_size}
@@ -225,7 +154,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
         python3 offline_inference.py \
             --output-dir $output_dir \
             --llm-model-name-or-path $huggingface_model_local_dir \
-            --token2wav-path $model_scope_model_local_dir \
+            --token2wav-path $step_audio_model_dir/token2wav \
             --backend $backend \
             --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
             --engine-dir $trt_engines_dir \
@@ -234,34 +163,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   done
 fi
 
-
-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-
-   CUDA_VISIBLE_DEVICES=2 python3 streaming_inference.py --enable-trt --strategy exponential
-
-
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+   echo "Running Step-Audio2-mini DiT Token2Wav inference using a single python script"
+   export CUDA_VISIBLE_DEVICES=1
+   # Note: Using pre-computed cosyvoice2 tokens
+   python3 streaming_inference.py --enable-trt --strategy equal # equal, exponential
+   # Offline Token2wav inference
+   # python3 token2wav_dit.py --enable-trt
 fi
 
 
-if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
-    CUDA_VISIBLE_DEVICES=0 mpirun -np 1 --allow-run-as-root --oversubscribe trtllm-serve serve --tokenizer $huggingface_model_local_dir $trt_engines_dir --max_batch_size 16  --kv_cache_free_gpu_memory_fraction 0.4
-    
-fi
-
-if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
-    #! /usr/bin/env bash
-    curl http://localhost:8000/v1/chat/completions \
-        -H "Content-Type: application/json" \
-        -d '{
-            "model": "trt_engines_bfloat16",
-            "messages":[{"role": "user", "content": "Where is New York?"},
-                        {"role": "assistant", "content": "<|s_1708|><|s_2050|><|s_2159|>"}],
-            "max_tokens": 512,
-            "temperature": 0.8,
-            "top_p": 0.95,
-            "top_k": 50,
-            "stop": ["<|eos1|>"],
-            "repetition_penalty": 1.2,
-            "stream": false
-        }'
-fi

+ 9 - 15
runtime/triton_trtllm/streaming_inference.py

@@ -54,7 +54,7 @@ if __name__ == "__main__":
     token2wav_model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt, streaming=True)
     
     flow_pre_lookahead_len = 3
-    CHUNK_SIZE = 15
+    CHUNK_SIZE = 25
     token_frame_rate = 25
     OVERLAP_SIZE = 0
 
@@ -67,20 +67,12 @@ if __name__ == "__main__":
             ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, prompt_speech_tokens_list, prompt_text_list = batch
 
             id, generated_speech_tokens, prompt_audio, prompt_audio_sample_rate = ids[0], generated_speech_tokens_list[0], prompt_audios_list[0], prompt_audios_sample_rate[0]
-            # if id != "unseen3_text5":
-            #     continue
-            # else:
-            #     a = torch.load("semantic_token_ids_arr_debug_871e2b90-42a7-4829-957c-b45e6a96fdb2.pt")
-            #     generated_speech_tokens = a["semantic_token_ids_arr"]
-            #     print(generated_speech_tokens)
+
             assert prompt_audio_sample_rate == 16000
 
             prompt_text = prompt_text_list[0]
             prompt_speech_tokens = prompt_speech_tokens_list[0]
 
-
-            # generated_ids_iter = fake_generated_id_iter(generated_speech_tokens)
-
             semantic_token_ids_arr, token_offset = [], 0
             flow_prompt_speech_token_len = len(prompt_speech_tokens)
     
@@ -114,14 +106,16 @@ if __name__ == "__main__":
 
             audios = output_wavs            
             reconstructed_audio = np.concatenate(audios)
-            # Save reconstructed audio
             sf.write(os.path.join(args.output_dir, f"{id}.wav"), reconstructed_audio, 24000, "PCM_16")
 
-
-            print(f"Saved {id}")
         end_time = time.time()
 
         if _ == 0:
             token2wav_model.speaker_cache = {}
-        print(f"Warmup time: {end_time - start_time} seconds")
-        print(f"Total forward count: {total_forward_count}")
+            print(f"Warmup time: {end_time - start_time} seconds")
+            print("clear speaker cache")
+        elif _ == 1:
+            print(f"Cost time without speaker cache: {end_time - start_time} seconds")
+        else:
+            print(f"Cost time with speaker cache: {end_time - start_time} seconds")
+            print(f"Total flow matching forward calls: {total_forward_count}")