root 4 bulan lalu
induk
melakukan
07cbc51cd1

+ 52 - 49
runtime/triton_trtllm/client_grpc.py

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
 #                2023  Nvidia              (authors: Yuekai Zhang)
 #                2023  Recurrent.ai        (authors: Songtao Shi)
@@ -46,7 +45,7 @@ import asyncio
 import json
 import queue  # Added
 import uuid  # Added
-import functools # Added
+import functools  # Added
 
 import os
 import time
@@ -56,9 +55,9 @@ from pathlib import Path
 import numpy as np
 import soundfile as sf
 import tritonclient
-import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
-import tritonclient.grpc as grpcclient_sync # Added sync client import
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
+import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
+import tritonclient.grpc as grpcclient_sync  # Added sync client import
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException
 
 
 # --- Added UserData and callback ---
@@ -76,9 +75,10 @@ class UserData:
             return self._first_chunk_time - self._start_time
         return None
 
+
 def callback(user_data, result, error):
     if user_data._first_chunk_time is None and not error:
-        user_data._first_chunk_time = time.time() # Record time of first successful chunk
+        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
     if error:
         user_data._completed_requests.put(error)
     else:
@@ -206,8 +206,11 @@ def get_args():
         "--model-name",
         type=str,
         default="f5_tts",
-        choices=["f5_tts", "spark_tts", "cosyvoice2"],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+        choices=[
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )
 
     parser.add_argument(
@@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000):
         waveform = resample(waveform, num_samples)
     return waveform, target_sample_rate
 
+
 def prepare_request_input_output(
-    protocol_client, # Can be grpcclient_aio or grpcclient_sync
+    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None # Optional padding for offline mode
+    padding_duration: int = None  # Optional padding for offline mode
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -291,9 +295,9 @@ def prepare_request_input_output(
         # Estimate target duration based on text length ratio (crude estimation)
         # Avoid division by zero if reference_text is empty
         if reference_text:
-             estimated_target_duration = duration / len(reference_text) * len(target_text)
+            estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
-             estimated_target_duration = duration # Assume target duration similar to reference if no text
+            estimated_target_duration = duration  # Assume target duration similar to reference if no text
 
         # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
@@ -329,6 +333,7 @@ def prepare_request_input_output(
 
     return inputs, outputs
 
+
 def run_sync_streaming_inference(
     sync_triton_client: tritonclient.grpc.InferenceServerClient,
     model_name: str,
@@ -342,7 +347,7 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
-    user_data.record_start_time() # Record start time for first chunk latency calculation
+    user_data.record_start_time()  # Record start time for first chunk latency calculation
 
     # Establish stream
     sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
@@ -360,11 +365,11 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get() # Add timeout
+            result = user_data._completed_requests.get()  # Add timeout
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
                 sync_triton_client.stop_stream()
-                return None, None, None # Indicate error
+                return None, None, None  # Indicate error
             # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
@@ -372,15 +377,15 @@ def run_sync_streaming_inference(
                 break
 
             audio_chunk = result.as_numpy("waveform").reshape(-1)
-            if audio_chunk.size > 0: # Only append non-empty chunks
-                 audios.append(audio_chunk)
+            if audio_chunk.size > 0:  # Only append non-empty chunks
+                audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")
 
         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
             sync_triton_client.stop_stream()
-            return None, None, None # Indicate error
+            return None, None, None  # Indicate error
 
     sync_triton_client.stop_stream()
     end_time_total = time.time()
@@ -398,19 +403,19 @@ def run_sync_streaming_inference(
         # Simplified reconstruction based on client_grpc_streaming.py
         if not audios:
             print("Warning: No audio chunks received.")
-            reconstructed_audio = np.array([], dtype=np.float32) # Empty array
+            reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
         elif len(audios) == 1:
             reconstructed_audio = audios[0]
         else:
-            reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
+            reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
             for i in range(1, len(audios)):
-                 # Cross-fade section
-                 cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
-                                        audios[i - 1][-cross_fade_samples:] * fade_out)
-                 # Middle section of the current chunk
-                 middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-                 # Concatenate
-                 reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
+                # Cross-fade section
+                cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
+                                       audios[i - 1][-cross_fade_samples:] * fade_out)
+                # Middle section of the current chunk
+                middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
+                # Concatenate
+                reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
             # Add the last part of the final chunk
             reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
 
@@ -421,11 +426,11 @@ def run_sync_streaming_inference(
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
         else:
             print("Warning: No audio chunks received or reconstructed.")
-            actual_duration = 0 # Set duration to 0 if no audio
+            actual_duration = 0  # Set duration to 0 if no audio
 
     else:
-         print("Warning: No audio chunks received.")
-         actual_duration = 0
+        print("Warning: No audio chunks received.")
+        actual_duration = 0
 
     return total_request_latency, first_chunk_latency, actual_duration
 
@@ -433,7 +438,7 @@ def run_sync_streaming_inference(
 async def send_streaming(
     manifest_item_list: list,
     name: str,
-    server_url: str, # Changed from sync_triton_client
+    server_url: str,  # Changed from sync_triton_client
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -445,11 +450,11 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
-    sync_triton_client = None # Initialize client variable
+    sync_triton_client = None  # Initialize client variable
 
-    try: # Wrap in try...finally to ensure client closing
+    try:  # Wrap in try...finally to ensure client closing
         print(f"{name}: Initializing sync client for streaming...")
-        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here
 
         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -491,8 +496,7 @@ async def send_streaming(
                     latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                     total_duration += actual_duration
                 else:
-                     print(f"{name}: Item {i} failed.")
-
+                    print(f"{name}: Item {i} failed.")
 
             except FileNotFoundError:
                 print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
@@ -501,8 +505,7 @@ async def send_streaming(
                 import traceback
                 traceback.print_exc()
 
-
-    finally: # Ensure client is closed
+    finally:  # Ensure client is closed
         if sync_triton_client:
             try:
                 print(f"{name}: Closing sync client...")
@@ -510,10 +513,10 @@ async def send_streaming(
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")
 
-
     print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
     return total_duration, latency_data
 
+
 async def send(
     manifest_item_list: list,
     name: str,
@@ -605,6 +608,7 @@ def split_data(data, k):
 
     return result
 
+
 async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
@@ -622,7 +626,7 @@ async def main():
         # Use the sync client for streaming tasks, handled via asyncio.to_thread
         # We will create one sync client instance PER TASK inside send_streaming.
         # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
-        protocol_client = grpcclient_sync # protocol client for input prep
+        protocol_client = grpcclient_sync  # protocol client for input prep
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
     # --- End Client Initialization ---
@@ -682,11 +686,11 @@ async def main():
                 )
             )
         elif args.mode == "streaming":
-             task = asyncio.create_task(
+            task = asyncio.create_task(
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
-                    server_url=url, # Pass URL instead of client
+                    server_url=url,  # Pass URL instead of client
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -709,16 +713,15 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
-            latency_data.extend(ans[1]) # Use extend for list of lists
+            latency_data.extend(ans[1])  # Use extend for list of lists
         else:
-             print("Warning: A task returned None, possibly due to an error.")
-
+            print("Warning: A task returned None, possibly due to an error.")
 
     if total_duration == 0:
         print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
         rtf = float('inf')
     else:
-         rtf = elapsed / total_duration
+        rtf = elapsed / total_duration
 
     s = f"Mode: {args.mode}\n"
     s += f"RTF: {rtf:.4f}\n"
@@ -759,7 +762,7 @@ async def main():
                 s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                 s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
             else:
-                 s += "No total request latency data collected.\n"
+                s += "No total request latency data collected.\n"
 
             s += "\n--- First Chunk Latency ---\n"
             if first_chunk_latency_list:
@@ -772,7 +775,7 @@ async def main():
                 s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                 s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
             else:
-                 s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
+                s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
     else:
         s += "No latency data collected.\n"
     # --- End Statistics Reporting ---
@@ -785,7 +788,7 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
-        name = "results" # Default name if no manifest/split/audio provided
+        name = "results"  # Default name if no manifest/split/audio provided
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)
 

+ 13 - 9
runtime/triton_trtllm/client_http.py

@@ -29,6 +29,7 @@ import json
 import numpy as np
 import argparse
 
+
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -67,9 +68,10 @@ def get_args():
         type=str,
         default="spark_tts",
         choices=[
-            "f5_tts", "spark_tts", "cosyvoice2"
-        ],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )
 
     parser.add_argument(
@@ -80,6 +82,7 @@ def get_args():
     )
     return parser.parse_args()
 
+
 def prepare_request(
     waveform,
     reference_text,
@@ -97,7 +100,7 @@ def prepare_request(
                 1,
                 padding_duration
                 * sample_rate
-                * ((int(duration) // padding_duration) + 1),
+                * ((int(len(waveform) / sample_rate) // padding_duration) + 1),
             ),
             dtype=np.float32,
         )
@@ -105,11 +108,11 @@ def prepare_request(
         samples[0, : len(waveform)] = waveform
     else:
         samples = waveform
-        
+
     samples = samples.reshape(1, -1).astype(np.float32)
 
     data = {
-        "inputs":[
+        "inputs": [
             {
                 "name": "reference_wav",
                 "shape": samples.shape,
@@ -139,16 +142,17 @@ def prepare_request(
 
     return data
 
+
 if __name__ == "__main__":
     args = get_args()
     server_url = args.server_url
     if not server_url.startswith(("http://", "https://")):
         server_url = f"http://{server_url}"
-    
+
     url = f"{server_url}/v2/models/{args.model_name}/infer"
     waveform, sr = sf.read(args.reference_audio)
     assert sr == 16000, "sample rate hardcoded in server"
-    
+
     samples = np.array(waveform, dtype=np.float32)
     data = prepare_request(samples, args.reference_text, args.target_text)
 
@@ -166,4 +170,4 @@ if __name__ == "__main__":
         sample_rate = 16000
     else:
         sample_rate = 24000
-    sf.write(args.output_audio, audio, sample_rate, "PCM_16")
+    sf.write(args.output_audio, audio, sample_rate, "PCM_16")

+ 11 - 10
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py

@@ -35,33 +35,34 @@ import s3tokenizer
 
 ORIGINAL_VOCAB_SIZE = 151663
 
+
 class TritonPythonModel:
     """Triton Python model for audio tokenization.
-    
+
     This model takes reference audio input and extracts semantic tokens
     using s3tokenizer.
     """
 
     def initialize(self, args):
         """Initialize the model.
-        
+
         Args:
             args: Dictionary containing model configuration
         """
         # Parse model parameters
         parameters = json.loads(args['model_config'])['parameters']
         model_params = {k: v["string_value"] for k, v in parameters.items()}
-        
+
         self.device = torch.device("cuda")
         model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
         self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
 
     def execute(self, requests):
         """Execute inference on the batched requests.
-        
+
         Args:
             requests: List of inference requests
-            
+
         Returns:
             List of inference responses containing tokenized outputs
         """
@@ -79,18 +80,18 @@ class TritonPythonModel:
             # Prepare inputs
             wav = wav_array[:, :wav_len].squeeze(0)
             mels.append(s3tokenizer.log_mel_spectrogram(wav))
-            
+
         mels, mels_lens = s3tokenizer.padding(mels)
         codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
         codes = codes.clone() + ORIGINAL_VOCAB_SIZE
-        
+
         responses = []
         for i in range(len(requests)):
-            prompt_speech_tokens = codes[i, :codes_lens[i].item()]            
+            prompt_speech_tokens = codes[i, :codes_lens[i].item()]
             prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
                 "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
             inference_response = pb_utils.InferenceResponse(
                 output_tensors=[prompt_speech_tokens_tensor])
             responses.append(inference_response)
-                             
-        return responses
+
+        return responses

+ 61 - 45
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

@@ -42,16 +42,17 @@ import onnxruntime
 
 from matcha.utils.audio import mel_spectrogram
 
+
 class TritonPythonModel:
     """Triton Python model for Spark TTS.
-    
+
     This model orchestrates the end-to-end TTS pipeline by coordinating
     between audio tokenizer, LLM, and vocoder components.
     """
-    
+
     def initialize(self, args):
         """Initialize the model.
-        
+
         Args:
             args: Dictionary containing model configuration
         """
@@ -116,58 +117,58 @@ class TritonPythonModel:
             "input_ids": input_ids,
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
         }
-        
+
         # Convert inputs to Triton tensors
         input_tensor_list = [
             pb_utils.Tensor(k, v) for k, v in input_dict.items()
         ]
-        
+
         # Create and execute inference request
         llm_request = pb_utils.InferenceRequest(
             model_name="tensorrt_llm",
             requested_output_names=["output_ids", "sequence_length"],
             inputs=input_tensor_list,
         )
-        
+
         llm_responses = llm_request.exec(decoupled=self.decoupled)
         if self.decoupled:
             for llm_response in llm_responses:
                 if llm_response.has_error():
                     raise pb_utils.TritonModelException(llm_response.error().message())
-                
+
                 # Extract and process output
                 output_ids = pb_utils.get_output_tensor_by_name(
                     llm_response, "output_ids").as_numpy()
                 seq_lens = pb_utils.get_output_tensor_by_name(
                     llm_response, "sequence_length").as_numpy()
-                
+
                 # Get actual output IDs up to the sequence length
                 actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
-                
+
                 yield actual_output_ids
         else:
             llm_response = llm_responses
             if llm_response.has_error():
                 raise pb_utils.TritonModelException(llm_response.error().message())
-            
+
             # Extract and process output
             output_ids = pb_utils.get_output_tensor_by_name(
                 llm_response, "output_ids").as_numpy()
             seq_lens = pb_utils.get_output_tensor_by_name(
                 llm_response, "sequence_length").as_numpy()
-            
+
             # Get actual output IDs up to the sequence length
             actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
-            
-            yield actual_output_ids    
-                
+
+            yield actual_output_ids
+
     def forward_audio_tokenizer(self, wav, wav_len):
         """Forward pass through the audio tokenizer component.
-        
+
         Args:
             wav: Input waveform tensor
             wav_len: Waveform length tensor
-            
+
         Returns:
             Tuple of global and semantic tokens
         """
@@ -176,26 +177,31 @@ class TritonPythonModel:
             requested_output_names=['prompt_speech_tokens'],
             inputs=[wav, wav_len]
         )
-        
+
         inference_response = inference_request.exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
-        
+
         # Extract and convert output tensors
         prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
         prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
 
         return prompt_speech_tokens
 
-    def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor:
+    def forward_token2wav(
+            self,
+            prompt_speech_tokens: torch.Tensor,
+            prompt_speech_feat: torch.Tensor,
+            prompt_spk_embedding: torch.Tensor,
+            target_speech_tokens: torch.Tensor) -> torch.Tensor:
         """Forward pass through the vocoder component.
-        
+
         Args:
             prompt_speech_tokens: Prompt speech tokens tensor
             prompt_speech_feat: Prompt speech feat tensor
             prompt_spk_embedding: Prompt spk embedding tensor
             target_speech_tokens: Target speech tokens tensor
-            
+
         Returns:
             Generated waveform tensor
         """
@@ -203,22 +209,22 @@ class TritonPythonModel:
         prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
         prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
-        
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
             requested_output_names=['waveform'],
             inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
         )
-        
+
         inference_response = inference_request.exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
-        
+
         # Extract and convert output waveform
         waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
         waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
-        
+
         return waveform
 
     def parse_input(self, text, prompt_text, prompt_speech_tokens):
@@ -231,43 +237,53 @@ class TritonPythonModel:
 
     def _extract_spk_embedding(self, speech):
         feat = kaldi.fbank(speech,
-                            num_mel_bins=80,
-                            dither=0,
-                            sample_frequency=16000)
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
         feat = feat - feat.mean(dim=0, keepdim=True)
         embedding = self.campplus_session.run(None,
-                                                {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
         embedding = torch.tensor([embedding]).to(self.device).half()
         return embedding
 
-
     def _extract_speech_feat(self, speech):
-        speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device)
+        speech_feat = mel_spectrogram(
+            speech,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000).squeeze(
+            dim=0).transpose(
+            0,
+            1).to(
+                self.device)
         speech_feat = speech_feat.unsqueeze(dim=0)
         return speech_feat
 
     def execute(self, requests):
         """Execute inference on the batched requests.
-        
+
         Args:
             requests: List of inference requests
-            
+
         Returns:
             List of inference responses containing generated audio
         """
         responses = []
-        
+
         for request in requests:
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
             wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
-            
+
             # Process reference audio through audio tokenizer
 
             prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
             prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
 
-
             wav_tensor = wav.as_numpy()
             wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
             prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
@@ -275,20 +291,20 @@ class TritonPythonModel:
             token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
             prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
             prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-            
+
             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')
-            
+
             target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
             target_text = target_text[0][0].decode('utf-8')
-            
+
             # Prepare prompt for LLM
             input_ids = self.parse_input(
                 text=target_text,
                 prompt_text=reference_text,
                 prompt_speech_tokens=prompt_speech_tokens,
             )
-            
+
             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)
 
@@ -305,13 +321,13 @@ class TritonPythonModel:
                 generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
                 prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
                 audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
-                
+
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-                self.logger.log_info(f"send tritonserver_response_complete_final to end")
+                self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
                 generated_ids = next(generated_ids_iter)
                 generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
@@ -320,11 +336,11 @@ class TritonPythonModel:
 
                 prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
                 audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
-                
+
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 responses.append(inference_response)
-            
+
         if not self.decoupled:
-            return responses
+            return responses

+ 11 - 13
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)
 
 ORIGINAL_VOCAB_SIZE = 151663
 
+
 class CosyVoice2:
 
     def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
@@ -66,6 +67,7 @@ class CosyVoice2:
                                 trt_concurrent,
                                 self.fp16)
 
+
 class CosyVoice2Model:
 
     def __init__(self,
@@ -109,16 +111,17 @@ class CosyVoice2Model:
         input_names = ["x", "mask", "mu", "cond"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
+
 class TritonPythonModel:
     """Triton Python model for vocoder.
-    
+
     This model takes global and semantic tokens as input and generates audio waveforms
     using the BiCodec vocoder.
     """
 
     def initialize(self, args):
         """Initialize the model.
-        
+
         Args:
             args: Dictionary containing model configuration
         """
@@ -126,24 +129,23 @@ class TritonPythonModel:
         parameters = json.loads(args['model_config'])['parameters']
         model_params = {key: value["string_value"] for key, value in parameters.items()}
         model_dir = model_params["model_dir"]
-        
+
         # Initialize device and vocoder
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
-        
+
         self.token2wav_model = CosyVoice2(
             model_dir, load_jit=True, load_trt=True, fp16=True
         )
 
         logger.info("Token2Wav initialized successfully")
 
-
     def execute(self, requests):
         """Execute inference on the batched requests.
-        
+
         Args:
             requests: List of inference requests
-            
+
         Returns:
             List of inference responses containing generated waveforms
         """
@@ -163,7 +165,7 @@ class TritonPythonModel:
             # shift the speech tokens according to the original vocab size
             prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
-            
+
             tts_mel, _ = self.token2wav_model.model.flow.inference(
                 token=target_speech_tokens,
                 token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
@@ -189,9 +191,5 @@ class TritonPythonModel:
             wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
             inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
             responses.append(inference_response)
-                             
-        return responses
-
-
-
 
+        return responses

+ 14 - 26
runtime/triton_trtllm/scripts/convert_checkpoint.py

@@ -35,8 +35,7 @@ def parse_arguments():
         type=str,
         default='auto',
         choices=['auto', 'float16', 'bfloat16', 'float32'],
-        help=
-        "The data type for the model weights and activations if not quantized. "
+        help="The data type for the model weights and activations if not quantized. "
         "If 'auto', the data type is automatically inferred from the source model; "
         "however, if the source dtype is float32, it is converted to float16.")
     parser.add_argument(
@@ -49,8 +48,7 @@ def parse_arguments():
         '--disable_weight_only_quant_plugin',
         default=False,
         action="store_true",
-        help=
-        'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
+        help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
         'You must also use --use_weight_only for that argument to have an impact.'
     )
     parser.add_argument(
@@ -60,16 +58,14 @@ def parse_arguments():
         nargs='?',
         default='int8',
         choices=['int8', 'int4', 'int4_gptq'],
-        help=
-        'Define the precision for the weights when using weight-only quantization.'
+        help='Define the precision for the weights when using weight-only quantization.'
         'You must also use --use_weight_only for that argument to have an impact.'
     )
     parser.add_argument(
         '--calib_dataset',
         type=str,
         default='ccdv/cnn_dailymail',
-        help=
-        "The huggingface dataset name or the local directory of the dataset for calibration."
+        help="The huggingface dataset name or the local directory of the dataset for calibration."
     )
     parser.add_argument(
         "--smoothquant",
@@ -83,31 +79,27 @@ def parse_arguments():
         '--per_channel',
         action="store_true",
         default=False,
-        help=
-        'By default, we use a single static scaling factor for the GEMM\'s result. '
+        help='By default, we use a single static scaling factor for the GEMM\'s result. '
         'per_channel instead uses a different static scaling factor for each channel. '
         'The latter is usually more accurate, but a little slower.')
     parser.add_argument(
         '--per_token',
         action="store_true",
         default=False,
-        help=
-        'By default, we use a single static scaling factor to scale activations in the int8 range. '
+        help='By default, we use a single static scaling factor to scale activations in the int8 range. '
         'per_token chooses at run time, and for each token, a custom scaling factor. '
         'The latter is usually more accurate, but a little slower.')
     parser.add_argument(
         '--int8_kv_cache',
         default=False,
         action="store_true",
-        help=
-        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
+        help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
     )
     parser.add_argument(
         '--per_group',
         default=False,
         action="store_true",
-        help=
-        'By default, we use a single static scaling factor to scale weights in the int4 range. '
+        help='By default, we use a single static scaling factor to scale weights in the int4 range. '
         'per_group chooses at run time, and for each group, a custom scaling factor. '
         'The flag is built for GPTQ/AWQ quantization.')
 
@@ -121,16 +113,14 @@ def parse_arguments():
         '--use_parallel_embedding',
         action="store_true",
         default=False,
-        help=
-        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
+        help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
     )
     parser.add_argument(
         '--embedding_sharding_dim',
         type=int,
         default=0,
         choices=[0, 1],
-        help=
-        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
+        help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
         'To shard it along hidden dimension, set embedding_sharding_dim=1'
         'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
     )
@@ -147,15 +137,13 @@ def parse_arguments():
         '--moe_tp_size',
         type=int,
         default=-1,
-        help=
-        'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
+        help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
     )
     parser.add_argument(
         '--moe_ep_size',
         type=int,
         default=-1,
-        help=
-        'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
+        help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
     )
     args = parser.parse_args()
     return args
@@ -249,7 +237,7 @@ def convert_and_save_hf(args):
                                                trust_remote_code=True)
         quant_config, override_fields = update_quant_config_from_hf(
             quant_config, hf_config, override_fields)
-    except:
+    except BaseException:
         logger.warning("AutoConfig cannot load the huggingface config.")
 
     if args.smoothquant is not None or args.int8_kv_cache:
@@ -339,4 +327,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    main()

+ 2 - 3
runtime/triton_trtllm/scripts/fill_template.py

@@ -1,4 +1,4 @@
-#! /usr/bin/env python3
+# /usr/bin/env python3
 from argparse import ArgumentParser
 from string import Template
 
@@ -59,8 +59,7 @@ if __name__ == "__main__":
     parser.add_argument("file_path", help="path of the .pbtxt to modify")
     parser.add_argument(
         "substitutions",
-        help=
-        "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
+        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
     )
     parser.add_argument("--in_place",
                         "-i",

+ 1 - 2
runtime/triton_trtllm/scripts/test_llm.py

@@ -46,7 +46,6 @@ def parse_arguments(args=None):
     parser.add_argument('--top_k', type=int, default=50)
     parser.add_argument('--top_p', type=float, default=0.95)
 
-
     return parser.parse_args(args=args)
 
 
@@ -60,7 +59,7 @@ def parse_input(tokenizer,
         input_ids = tokenizer.encode(
             curr_text)
         batch_input_ids.append(input_ids)
- 
+
     batch_input_ids = [
         torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
     ]