root 4 months ago
parent commit 07cbc51cd1

+ 52 - 49
runtime/triton_trtllm/client_grpc.py

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
 #                2023  Nvidia              (authors: Yuekai Zhang)
 #                2023  Recurrent.ai        (authors: Songtao Shi)
@@ -46,7 +45,7 @@ import asyncio
 import json
 import queue  # Added
 import uuid  # Added
-import functools # Added
+import functools  # Added
 
 import os
 import time
@@ -56,9 +55,9 @@ from pathlib import Path
 import numpy as np
 import soundfile as sf
 import tritonclient
-import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
-import tritonclient.grpc as grpcclient_sync # Added sync client import
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
+import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
+import tritonclient.grpc as grpcclient_sync  # Added sync client import
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException
 
 
 # --- Added UserData and callback ---
@@ -76,9 +75,10 @@ class UserData:
             return self._first_chunk_time - self._start_time
         return None
 
+
 def callback(user_data, result, error):
     if user_data._first_chunk_time is None and not error:
-        user_data._first_chunk_time = time.time() # Record time of first successful chunk
+        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
     if error:
         user_data._completed_requests.put(error)
     else:
@@ -206,8 +206,11 @@ def get_args():
         "--model-name",
         "--model-name",
         type=str,
         type=str,
         default="f5_tts",
         default="f5_tts",
-        choices=["f5_tts", "spark_tts", "cosyvoice2"],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+        choices=[
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )
     )
 
 
     parser.add_argument(
     parser.add_argument(
@@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000):
         waveform = resample(waveform, num_samples)
     return waveform, target_sample_rate
 
+
 def prepare_request_input_output(
-    protocol_client, # Can be grpcclient_aio or grpcclient_sync
+    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None # Optional padding for offline mode
+    padding_duration: int = None  # Optional padding for offline mode
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -291,9 +295,9 @@ def prepare_request_input_output(
         # Estimate target duration based on text length ratio (crude estimation)
         # Avoid division by zero if reference_text is empty
         if reference_text:
-             estimated_target_duration = duration / len(reference_text) * len(target_text)
+            estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
-             estimated_target_duration = duration # Assume target duration similar to reference if no text
+            estimated_target_duration = duration  # Assume target duration similar to reference if no text
 
         # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
@@ -329,6 +333,7 @@ prepare_request_input_output(
 
     return inputs, outputs
 
+
 def run_sync_streaming_inference(
     sync_triton_client: tritonclient.grpc.InferenceServerClient,
     model_name: str,
@@ -342,7 +347,7 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
-    user_data.record_start_time() # Record start time for first chunk latency calculation
+    user_data.record_start_time()  # Record start time for first chunk latency calculation
 
     # Establish stream
     sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
@@ -360,11 +365,11 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get() # Add timeout
+            result = user_data._completed_requests.get()  # Add timeout
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
                 sync_triton_client.stop_stream()
-                return None, None, None # Indicate error
+                return None, None, None  # Indicate error
             # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
@@ -372,15 +377,15 @@ def run_sync_streaming_inference(
                 break
 
             audio_chunk = result.as_numpy("waveform").reshape(-1)
-            if audio_chunk.size > 0: # Only append non-empty chunks
-                 audios.append(audio_chunk)
+            if audio_chunk.size > 0:  # Only append non-empty chunks
+                audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")
 
         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
             sync_triton_client.stop_stream()
-            return None, None, None # Indicate error
+            return None, None, None  # Indicate error
 
     sync_triton_client.stop_stream()
     end_time_total = time.time()
@@ -398,19 +403,19 @@ def run_sync_streaming_inference(
         # Simplified reconstruction based on client_grpc_streaming.py
         if not audios:
             print("Warning: No audio chunks received.")
-            reconstructed_audio = np.array([], dtype=np.float32) # Empty array
+            reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
         elif len(audios) == 1:
             reconstructed_audio = audios[0]
         else:
-            reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
+            reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
             for i in range(1, len(audios)):
-                 # Cross-fade section
-                 cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
-                                        audios[i - 1][-cross_fade_samples:] * fade_out)
-                 # Middle section of the current chunk
-                 middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-                 # Concatenate
-                 reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
+                # Cross-fade section
+                cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
+                                       audios[i - 1][-cross_fade_samples:] * fade_out)
+                # Middle section of the current chunk
+                middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
+                # Concatenate
+                reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
             # Add the last part of the final chunk
             reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
 
@@ -421,11 +426,11 @@ def run_sync_streaming_inference(
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
         else:
             print("Warning: No audio chunks received or reconstructed.")
-            actual_duration = 0 # Set duration to 0 if no audio
+            actual_duration = 0  # Set duration to 0 if no audio
 
     else:
-         print("Warning: No audio chunks received.")
-         actual_duration = 0
+        print("Warning: No audio chunks received.")
+        actual_duration = 0
 
     return total_request_latency, first_chunk_latency, actual_duration
 
@@ -433,7 +438,7 @@ def run_sync_streaming_inference(
 async def send_streaming(
     manifest_item_list: list,
     name: str,
-    server_url: str, # Changed from sync_triton_client
+    server_url: str,  # Changed from sync_triton_client
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -445,11 +450,11 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
-    sync_triton_client = None # Initialize client variable
+    sync_triton_client = None  # Initialize client variable
 
-    try: # Wrap in try...finally to ensure client closing
+    try:  # Wrap in try...finally to ensure client closing
         print(f"{name}: Initializing sync client for streaming...")
-        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here
 
         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -491,8 +496,7 @@ async def send_streaming(
                     latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                     total_duration += actual_duration
                 else:
-                     print(f"{name}: Item {i} failed.")
-
+                    print(f"{name}: Item {i} failed.")
 
             except FileNotFoundError:
                 print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
@@ -501,8 +505,7 @@ async def send_streaming(
                 import traceback
                 traceback.print_exc()
 
-
-    finally: # Ensure client is closed
+    finally:  # Ensure client is closed
         if sync_triton_client:
             try:
                 print(f"{name}: Closing sync client...")
@@ -510,10 +513,10 @@ async def send_streaming(
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")
 
-
     print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
     return total_duration, latency_data
 
+
 async def send(
     manifest_item_list: list,
     name: str,
@@ -605,6 +608,7 @@ def split_data(data, k):
 
     return result
 
+
 async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
@@ -622,7 +626,7 @@ async def main():
         # Use the sync client for streaming tasks, handled via asyncio.to_thread
         # We will create one sync client instance PER TASK inside send_streaming.
         # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
-        protocol_client = grpcclient_sync # protocol client for input prep
+        protocol_client = grpcclient_sync  # protocol client for input prep
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
     # --- End Client Initialization ---
@@ -682,11 +686,11 @@ async def main():
                 )
             )
         elif args.mode == "streaming":
-             task = asyncio.create_task(
+            task = asyncio.create_task(
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
-                    server_url=url, # Pass URL instead of client
+                    server_url=url,  # Pass URL instead of client
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -709,16 +713,15 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
-            latency_data.extend(ans[1]) # Use extend for list of lists
+            latency_data.extend(ans[1])  # Use extend for list of lists
         else:
-             print("Warning: A task returned None, possibly due to an error.")
-
+            print("Warning: A task returned None, possibly due to an error.")
 
     if total_duration == 0:
         print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
         rtf = float('inf')
     else:
-         rtf = elapsed / total_duration
+        rtf = elapsed / total_duration
 
     s = f"Mode: {args.mode}\n"
     s += f"RTF: {rtf:.4f}\n"
@@ -759,7 +762,7 @@ async def main():
                 s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                 s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                 s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
                 s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
             else:
             else:
-                 s += "No total request latency data collected.\n"
+                s += "No total request latency data collected.\n"
 
 
             s += "\n--- First Chunk Latency ---\n"
             s += "\n--- First Chunk Latency ---\n"
             if first_chunk_latency_list:
             if first_chunk_latency_list:
@@ -772,7 +775,7 @@ async def main():
                 s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                 s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                 s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
                 s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
             else:
             else:
-                 s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
+                s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
     else:
     else:
         s += "No latency data collected.\n"
         s += "No latency data collected.\n"
     # --- End Statistics Reporting ---
     # --- End Statistics Reporting ---
@@ -785,7 +788,7 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
-        name = "results" # Default name if no manifest/split/audio provided
+        name = "results"  # Default name if no manifest/split/audio provided
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)
 

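Note on the streaming path above: it drives tritonclient's synchronous gRPC stream from a callback that feeds a queue, then stitches the returned chunks with a linear cross-fade. The following condensed sketch is illustrative only; the fade windows and the enable_empty_final_response flag are assumptions not visible in this excerpt, and inputs/outputs are assumed to come from prepare_request_input_output().

# Minimal sketch of the sync-streaming pattern used by send_streaming()/run_sync_streaming_inference().
import functools
import queue

import numpy as np
import tritonclient.grpc as grpcclient_sync
from tritonclient.utils import InferenceServerException


def stream_once(server_url, model_name, inputs, outputs, request_id, cross_fade_samples):
    completed = queue.Queue()
    client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
    # Every streamed response (or error) is pushed into the queue by the callback.
    client.start_stream(callback=functools.partial(
        lambda q, result, error: q.put(error if error else result), completed))
    client.async_stream_infer(model_name, inputs, request_id=request_id,
                              outputs=outputs, enable_empty_final_response=True)

    chunks = []
    while True:
        result = completed.get()
        if isinstance(result, InferenceServerException):
            break  # the real client also stops the stream and reports the failure
        response = result.get_response()
        if response.parameters["triton_final_response"].bool_param:
            break
        chunk = result.as_numpy("waveform").reshape(-1)
        if chunk.size > 0:
            chunks.append(chunk)
    client.stop_stream()
    client.close()

    # Stitch adjacent chunks with a linear cross-fade over the overlap region.
    if len(chunks) <= 1:
        return chunks[0] if chunks else np.array([], dtype=np.float32)
    fade_in = np.linspace(0.0, 1.0, cross_fade_samples)
    fade_out = 1.0 - fade_in
    audio = chunks[0][:-cross_fade_samples]
    for prev, cur in zip(chunks, chunks[1:]):
        overlap = cur[:cross_fade_samples] * fade_in + prev[-cross_fade_samples:] * fade_out
        audio = np.concatenate([audio, overlap, cur[cross_fade_samples:-cross_fade_samples]])
    return np.concatenate([audio, chunks[-1][-cross_fade_samples:]])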
+ 13 - 9
runtime/triton_trtllm/client_http.py

@@ -29,6 +29,7 @@ import json
 import numpy as np
 import argparse
 
+
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -67,9 +68,10 @@ def get_args():
         type=str,
         default="spark_tts",
         choices=[
-            "f5_tts", "spark_tts", "cosyvoice2"
-        ],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )
 
     parser.add_argument(
@@ -80,6 +82,7 @@ def get_args():
     )
     return parser.parse_args()
 
+
 def prepare_request(
     waveform,
     reference_text,
@@ -97,7 +100,7 @@ def prepare_request(
                 1,
                 padding_duration
                 * sample_rate
-                * ((int(duration) // padding_duration) + 1),
+                * ((int(len(waveform) / sample_rate) // padding_duration) + 1),
             ),
             dtype=np.float32,
         )
@@ -105,11 +108,11 @@ def prepare_request(
         samples[0, : len(waveform)] = waveform
     else:
         samples = waveform
-        
+
     samples = samples.reshape(1, -1).astype(np.float32)
 
     data = {
-        "inputs":[
+        "inputs": [
             {
                 "name": "reference_wav",
                 "shape": samples.shape,
@@ -139,16 +142,17 @@ def prepare_request(
 
     return data
 
+
 if __name__ == "__main__":
     args = get_args()
     server_url = args.server_url
     if not server_url.startswith(("http://", "https://")):
         server_url = f"http://{server_url}"
-    
+
     url = f"{server_url}/v2/models/{args.model_name}/infer"
     waveform, sr = sf.read(args.reference_audio)
     assert sr == 16000, "sample rate hardcoded in server"
-    
+
     samples = np.array(waveform, dtype=np.float32)
     data = prepare_request(samples, args.reference_text, args.target_text)
 
@@ -166,4 +170,4 @@ if __name__ == "__main__":
         sample_rate = 16000
     else:
         sample_rate = 24000
-    sf.write(args.output_audio, audio, sample_rate, "PCM_16")
+    sf.write(args.output_audio, audio, sample_rate, "PCM_16")

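For reference, the padded-buffer change in prepare_request() now derives the buffer length from the waveform itself, rounding up to the next multiple of padding_duration seconds. A quick worked example (the values are illustrative, not from the commit):

# Worked example of the padded-buffer sizing in prepare_request().
import numpy as np

sample_rate = 16000
padding_duration = 10                                            # pad to a multiple of 10 s
waveform = np.zeros(int(7.2 * sample_rate), dtype=np.float32)    # ~7.2 s reference clip

num_padded = padding_duration * sample_rate * (
    (int(len(waveform) / sample_rate) // padding_duration) + 1)  # 160000 samples = 10 s
samples = np.zeros((1, num_padded), dtype=np.float32)
samples[0, : len(waveform)] = waveform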
+ 11 - 10
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py

@@ -35,33 +35,34 @@ import s3tokenizer
 
 ORIGINAL_VOCAB_SIZE = 151663
 
+
 class TritonPythonModel:
     """Triton Python model for audio tokenization.
-    
+
     This model takes reference audio input and extracts semantic tokens
     using s3tokenizer.
     """
 
     def initialize(self, args):
         """Initialize the model.
-        
+
         Args:
             args: Dictionary containing model configuration
         """
         # Parse model parameters
         parameters = json.loads(args['model_config'])['parameters']
         model_params = {k: v["string_value"] for k, v in parameters.items()}
-        
+
         self.device = torch.device("cuda")
         model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
         self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
 
     def execute(self, requests):
         """Execute inference on the batched requests.
-        
+
         Args:
             requests: List of inference requests
-            
+
         Returns:
             List of inference responses containing tokenized outputs
         """
@@ -79,18 +80,18 @@ class TritonPythonModel:
             # Prepare inputs
             wav = wav_array[:, :wav_len].squeeze(0)
             mels.append(s3tokenizer.log_mel_spectrogram(wav))
-            
+
         mels, mels_lens = s3tokenizer.padding(mels)
         codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
         codes = codes.clone() + ORIGINAL_VOCAB_SIZE
-        
+
         responses = []
         for i in range(len(requests)):
-            prompt_speech_tokens = codes[i, :codes_lens[i].item()]            
+            prompt_speech_tokens = codes[i, :codes_lens[i].item()]
             prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
                 "prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
             inference_response = pb_utils.InferenceResponse(
                 output_tensors=[prompt_speech_tokens_tensor])
             responses.append(inference_response)
-                             
-        return responses
+
+        return responses

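The tokenizer model above is a thin wrapper around s3tokenizer. Outside Triton the same flow looks roughly like the sketch below; the file paths and the 16 kHz mono input are assumptions, and the s3tokenizer calls simply mirror those in model.py:

# Standalone sketch of the tokenization flow wrapped by the model above.
import s3tokenizer
import soundfile as sf
import torch

ORIGINAL_VOCAB_SIZE = 151663  # offset so speech-token ids land above the text vocabulary

device = torch.device("cuda")
tokenizer = s3tokenizer.load_model("/workspace/models/speech_tokenizer_v2.onnx").to(device)

wav, sr = sf.read("reference.wav", dtype="float32")              # expected: 16 kHz mono
mel = s3tokenizer.log_mel_spectrogram(torch.from_numpy(wav))
mels, mels_lens = s3tokenizer.padding([mel])
codes, codes_lens = tokenizer.quantize(mels.to(device), mels_lens.to(device))
prompt_speech_tokens = codes[0, :codes_lens[0].item()] + ORIGINAL_VOCAB_SIZE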
+ 61 - 45
runtime/triton_trtllm/model_repo/cosyvoice2/1/model.py

@@ -42,16 +42,17 @@ import onnxruntime
 
 from matcha.utils.audio import mel_spectrogram
 
+
 class TritonPythonModel:
     """Triton Python model for Spark TTS.
-    
+
     This model orchestrates the end-to-end TTS pipeline by coordinating
     between audio tokenizer, LLM, and vocoder components.
     """
-    
+
     def initialize(self, args):
         """Initialize the model.
-        
+
         Args:
             args: Dictionary containing model configuration
         """
@@ -116,58 +117,58 @@ class TritonPythonModel:
             "input_ids": input_ids,
             "input_ids": input_ids,
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
             "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
         }
         }
-        
+
         # Convert inputs to Triton tensors
         # Convert inputs to Triton tensors
         input_tensor_list = [
         input_tensor_list = [
             pb_utils.Tensor(k, v) for k, v in input_dict.items()
             pb_utils.Tensor(k, v) for k, v in input_dict.items()
         ]
         ]
-        
+
         # Create and execute inference request
         # Create and execute inference request
         llm_request = pb_utils.InferenceRequest(
         llm_request = pb_utils.InferenceRequest(
             model_name="tensorrt_llm",
             model_name="tensorrt_llm",
             requested_output_names=["output_ids", "sequence_length"],
             requested_output_names=["output_ids", "sequence_length"],
             inputs=input_tensor_list,
             inputs=input_tensor_list,
         )
         )
-        
+
         llm_responses = llm_request.exec(decoupled=self.decoupled)
         llm_responses = llm_request.exec(decoupled=self.decoupled)
         if self.decoupled:
         if self.decoupled:
             for llm_response in llm_responses:
             for llm_response in llm_responses:
                 if llm_response.has_error():
                 if llm_response.has_error():
                     raise pb_utils.TritonModelException(llm_response.error().message())
                     raise pb_utils.TritonModelException(llm_response.error().message())
-                
+
                 # Extract and process output
                 # Extract and process output
                 output_ids = pb_utils.get_output_tensor_by_name(
                 output_ids = pb_utils.get_output_tensor_by_name(
                     llm_response, "output_ids").as_numpy()
                     llm_response, "output_ids").as_numpy()
                 seq_lens = pb_utils.get_output_tensor_by_name(
                 seq_lens = pb_utils.get_output_tensor_by_name(
                     llm_response, "sequence_length").as_numpy()
                     llm_response, "sequence_length").as_numpy()
-                
+
                 # Get actual output IDs up to the sequence length
                 # Get actual output IDs up to the sequence length
                 actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
                 actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
-                
+
                 yield actual_output_ids
                 yield actual_output_ids
         else:
         else:
             llm_response = llm_responses
             llm_response = llm_responses
             if llm_response.has_error():
             if llm_response.has_error():
                 raise pb_utils.TritonModelException(llm_response.error().message())
                 raise pb_utils.TritonModelException(llm_response.error().message())
-            
+
             # Extract and process output
             # Extract and process output
             output_ids = pb_utils.get_output_tensor_by_name(
             output_ids = pb_utils.get_output_tensor_by_name(
                 llm_response, "output_ids").as_numpy()
                 llm_response, "output_ids").as_numpy()
             seq_lens = pb_utils.get_output_tensor_by_name(
             seq_lens = pb_utils.get_output_tensor_by_name(
                 llm_response, "sequence_length").as_numpy()
                 llm_response, "sequence_length").as_numpy()
-            
+
             # Get actual output IDs up to the sequence length
             # Get actual output IDs up to the sequence length
             actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
             actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
-            
-            yield actual_output_ids    
-                
+
+            yield actual_output_ids
+
     def forward_audio_tokenizer(self, wav, wav_len):
     def forward_audio_tokenizer(self, wav, wav_len):
         """Forward pass through the audio tokenizer component.
         """Forward pass through the audio tokenizer component.
-        
+
         Args:
         Args:
             wav: Input waveform tensor
             wav: Input waveform tensor
             wav_len: Waveform length tensor
             wav_len: Waveform length tensor
-            
+
         Returns:
         Returns:
             Tuple of global and semantic tokens
             Tuple of global and semantic tokens
         """
         """
@@ -176,26 +177,31 @@ class TritonPythonModel:
             requested_output_names=['prompt_speech_tokens'],
             inputs=[wav, wav_len]
         )
-        
+
         inference_response = inference_request.exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
-        
+
         # Extract and convert output tensors
         prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
         prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
 
         return prompt_speech_tokens
 
-    def forward_token2wav(self, prompt_speech_tokens: torch.Tensor, prompt_speech_feat: torch.Tensor, prompt_spk_embedding: torch.Tensor, target_speech_tokens: torch.Tensor) -> torch.Tensor:
+    def forward_token2wav(
+            self,
+            prompt_speech_tokens: torch.Tensor,
+            prompt_speech_feat: torch.Tensor,
+            prompt_spk_embedding: torch.Tensor,
+            target_speech_tokens: torch.Tensor) -> torch.Tensor:
         """Forward pass through the vocoder component.
-        
+
         Args:
             prompt_speech_tokens: Prompt speech tokens tensor
             prompt_speech_feat: Prompt speech feat tensor
             prompt_spk_embedding: Prompt spk embedding tensor
             target_speech_tokens: Target speech tokens tensor
-            
+
         Returns:
             Generated waveform tensor
         """
@@ -203,22 +209,22 @@ class TritonPythonModel:
         prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
         prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
         target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
-        
+
         # Create and execute inference request
         inference_request = pb_utils.InferenceRequest(
             model_name='token2wav',
             requested_output_names=['waveform'],
             inputs=[prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor, target_speech_tokens_tensor]
         )
-        
+
         inference_response = inference_request.exec()
         if inference_response.has_error():
             raise pb_utils.TritonModelException(inference_response.error().message())
-        
+
         # Extract and convert output waveform
         waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
         waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
-        
+
         return waveform
 
     def parse_input(self, text, prompt_text, prompt_speech_tokens):
@@ -231,43 +237,53 @@
 
     def _extract_spk_embedding(self, speech):
         feat = kaldi.fbank(speech,
-                            num_mel_bins=80,
-                            dither=0,
-                            sample_frequency=16000)
+                           num_mel_bins=80,
+                           dither=0,
+                           sample_frequency=16000)
         feat = feat - feat.mean(dim=0, keepdim=True)
         embedding = self.campplus_session.run(None,
-                                                {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
         embedding = torch.tensor([embedding]).to(self.device).half()
         return embedding
 
-
     def _extract_speech_feat(self, speech):
-        speech_feat = mel_spectrogram(speech, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=8000).squeeze(dim=0).transpose(0, 1).to(self.device)
+        speech_feat = mel_spectrogram(
+            speech,
+            n_fft=1920,
+            num_mels=80,
+            sampling_rate=24000,
+            hop_size=480,
+            win_size=1920,
+            fmin=0,
+            fmax=8000).squeeze(
+            dim=0).transpose(
+            0,
+            1).to(
+                self.device)
         speech_feat = speech_feat.unsqueeze(dim=0)
         return speech_feat
 
     def execute(self, requests):
         """Execute inference on the batched requests.
-        
+
         Args:
             requests: List of inference requests
-            
+
         Returns:
             List of inference responses containing generated audio
         """
         responses = []
-        
+
         for request in requests:
             # Extract input tensors
             wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
             wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
-            
+
             # Process reference audio through audio tokenizer
 
             prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
             prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
 
-
             wav_tensor = wav.as_numpy()
             wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
             prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
@@ -275,20 +291,20 @@ class TritonPythonModel:
             token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
             prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
             prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
-            
+
             reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
             reference_text = reference_text[0][0].decode('utf-8')
-            
+
             target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
             target_text = target_text[0][0].decode('utf-8')
-            
+
             # Prepare prompt for LLM
             input_ids = self.parse_input(
                 text=target_text,
                 prompt_text=reference_text,
                 prompt_speech_tokens=prompt_speech_tokens,
             )
-            
+
             # Generate semantic tokens with LLM
             generated_ids_iter = self.forward_llm(input_ids)
 
@@ -305,13 +321,13 @@ class TritonPythonModel:
                 generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(torch.int32).to(self.device)
                 prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
                 audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
-                
+
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 response_sender.send(inference_response)
                 response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
-                self.logger.log_info(f"send tritonserver_response_complete_final to end")
+                self.logger.log_info("send tritonserver_response_complete_final to end")
             else:
                 generated_ids = next(generated_ids_iter)
                 generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
@@ -320,11 +336,11 @@
 
                 prompt_spk_embedding = self._extract_spk_embedding(wav_tensor)
                 audio = self.forward_token2wav(prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, generated_ids)
-                
+
                 # Prepare response
                 audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
                 inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
                 responses.append(inference_response)
-            
+
         if not self.decoupled:
-            return responses
+            return responses

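One detail in the prompt preparation above: mel features are extracted at 24 kHz with hop_size=480 (50 frames per second), while the speech tokens run at roughly 25 per second, so there are about two feature frames per token. That ratio is why the code keeps 2 * token_len feature frames for token_len prompt tokens; a quick check of the arithmetic (the clip length is an arbitrary example):

# Frame-rate bookkeeping behind the `token_len` / `2 * token_len` trim above
# (the 5 s clip length is an arbitrary example; the 25 Hz token rate is approximate).
sampling_rate = 24000
hop_size = 480
token_rate = 25

clip_seconds = 5
mel_frames = clip_seconds * sampling_rate // hop_size   # 250 frames (50 per second)
speech_tokens = clip_seconds * token_rate               # 125 tokens
token_len = min(mel_frames // 2, speech_tokens)         # 125
kept_feature_frames = 2 * token_len                     # 250, i.e. the full feature length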
+ 11 - 13
runtime/triton_trtllm/model_repo/token2wav/1/model.py

@@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)
 
 ORIGINAL_VOCAB_SIZE = 151663
 
+
 class CosyVoice2:
 
     def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
@@ -66,6 +67,7 @@ class CosyVoice2:
                                 trt_concurrent,
                                 self.fp16)
 
+
 class CosyVoice2Model:
 
     def __init__(self,
@@ -109,16 +111,17 @@ class CosyVoice2Model:
         input_names = ["x", "mask", "mu", "cond"]
         input_names = ["x", "mask", "mu", "cond"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
 
+
 class TritonPythonModel:
 class TritonPythonModel:
     """Triton Python model for vocoder.
     """Triton Python model for vocoder.
-    
+
     This model takes global and semantic tokens as input and generates audio waveforms
     This model takes global and semantic tokens as input and generates audio waveforms
     using the BiCodec vocoder.
     using the BiCodec vocoder.
     """
     """
 
 
     def initialize(self, args):
     def initialize(self, args):
         """Initialize the model.
         """Initialize the model.
-        
+
         Args:
         Args:
             args: Dictionary containing model configuration
             args: Dictionary containing model configuration
         """
         """
@@ -126,24 +129,23 @@ class TritonPythonModel:
         parameters = json.loads(args['model_config'])['parameters']
         model_params = {key: value["string_value"] for key, value in parameters.items()}
         model_dir = model_params["model_dir"]
-        
+
         # Initialize device and vocoder
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
-        
+
         self.token2wav_model = CosyVoice2(
             model_dir, load_jit=True, load_trt=True, fp16=True
         )
 
         logger.info("Token2Wav initialized successfully")
 
-
     def execute(self, requests):
         """Execute inference on the batched requests.
-        
+
         Args:
             requests: List of inference requests
-            
+
         Returns:
             List of inference responses containing generated waveforms
         """
@@ -163,7 +165,7 @@ class TritonPythonModel:
             # shift the speech tokens according to the original vocab size
             prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
-            
+
             tts_mel, _ = self.token2wav_model.model.flow.inference(
                 token=target_speech_tokens,
                 token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
@@ -189,9 +191,5 @@ class TritonPythonModel:
             wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
             inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
             responses.append(inference_response)
-                             
-        return responses
-
-
-
 
+        return responses

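ORIGINAL_VOCAB_SIZE appears at both ends of the pipeline: audio_tokenizer shifts the speech-token ids up by 151663 so they can share the LLM's id space with text tokens, and token2wav shifts them back down before flow.inference(). A tiny round-trip sketch (the token values are made up for illustration):

# Round trip of the vocab-size offset shared by audio_tokenizer and token2wav.
import torch

ORIGINAL_VOCAB_SIZE = 151663

raw_codes = torch.tensor([[12, 407, 3315]])     # ids produced by the speech tokenizer
llm_ids = raw_codes + ORIGINAL_VOCAB_SIZE       # ids as seen by the TensorRT-LLM model
vocoder_ids = llm_ids - ORIGINAL_VOCAB_SIZE     # ids handed to flow.inference()
assert torch.equal(vocoder_ids, raw_codes)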
+ 14 - 26
runtime/triton_trtllm/scripts/convert_checkpoint.py

@@ -35,8 +35,7 @@ def parse_arguments():
         type=str,
         default='auto',
         choices=['auto', 'float16', 'bfloat16', 'float32'],
-        help=
-        "The data type for the model weights and activations if not quantized. "
+        help="The data type for the model weights and activations if not quantized. "
         "If 'auto', the data type is automatically inferred from the source model; "
         "however, if the source dtype is float32, it is converted to float16.")
     parser.add_argument(
@@ -49,8 +48,7 @@ def parse_arguments():
         '--disable_weight_only_quant_plugin',
         default=False,
         action="store_true",
-        help=
-        'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
+        help='By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
         'You must also use --use_weight_only for that argument to have an impact.'
     )
     parser.add_argument(
@@ -60,16 +58,14 @@ def parse_arguments():
         nargs='?',
         default='int8',
         choices=['int8', 'int4', 'int4_gptq'],
-        help=
-        'Define the precision for the weights when using weight-only quantization.'
+        help='Define the precision for the weights when using weight-only quantization.'
         'You must also use --use_weight_only for that argument to have an impact.'
     )
     parser.add_argument(
         '--calib_dataset',
         type=str,
         default='ccdv/cnn_dailymail',
-        help=
-        "The huggingface dataset name or the local directory of the dataset for calibration."
+        help="The huggingface dataset name or the local directory of the dataset for calibration."
     )
     parser.add_argument(
         "--smoothquant",
@@ -83,31 +79,27 @@ def parse_arguments():
         '--per_channel',
         action="store_true",
         default=False,
-        help=
-        'By default, we use a single static scaling factor for the GEMM\'s result. '
+        help='By default, we use a single static scaling factor for the GEMM\'s result. '
         'per_channel instead uses a different static scaling factor for each channel. '
         'The latter is usually more accurate, but a little slower.')
     parser.add_argument(
         '--per_token',
         action="store_true",
         default=False,
-        help=
-        'By default, we use a single static scaling factor to scale activations in the int8 range. '
+        help='By default, we use a single static scaling factor to scale activations in the int8 range. '
         'per_token chooses at run time, and for each token, a custom scaling factor. '
         'The latter is usually more accurate, but a little slower.')
     parser.add_argument(
         '--int8_kv_cache',
         default=False,
         action="store_true",
-        help=
-        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
+        help='By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
     )
     parser.add_argument(
         '--per_group',
         default=False,
         action="store_true",
-        help=
-        'By default, we use a single static scaling factor to scale weights in the int4 range. '
+        help='By default, we use a single static scaling factor to scale weights in the int4 range. '
         'per_group chooses at run time, and for each group, a custom scaling factor. '
         'The flag is built for GPTQ/AWQ quantization.')
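Note: the --per_channel and --per_token help strings above contrast one static scale for a whole GEMM operand with finer-grained scales. A short NumPy sketch of that difference (illustrative only, not the converter's code):

```python
import numpy as np

# Illustrative only: symmetric int8 quantization of a weight matrix W
# (out_features x in_features) with per-tensor vs. per-channel scales.
W = np.random.randn(4, 8).astype(np.float32)

# Per-tensor: a single static scaling factor for the whole operand.
scale_tensor = np.abs(W).max() / 127.0
W_q_tensor = np.clip(np.round(W / scale_tensor), -127, 127).astype(np.int8)

# Per-channel: a different static scaling factor for each output channel (row).
scale_channel = np.abs(W).max(axis=1, keepdims=True) / 127.0
W_q_channel = np.clip(np.round(W / scale_channel), -127, 127).astype(np.int8)

# Dequantize and compare reconstruction error; per-channel is usually more accurate.
print("per-tensor error:", np.abs(W_q_tensor * scale_tensor - W).mean())
print("per-channel error:", np.abs(W_q_channel * scale_channel - W).mean())
```

--per_token applies the same idea to activations, choosing a scale per token at run time, and --per_group does it per weight group for the int4 GPTQ/AWQ path.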
 
 
@@ -121,16 +113,14 @@ def parse_arguments():
         '--use_parallel_embedding',
         action="store_true",
         default=False,
-        help=
-        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
+        help='By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
     )
     parser.add_argument(
         '--embedding_sharding_dim',
         type=int,
         default=0,
         choices=[0, 1],
-        help=
-        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
+        help='By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
         'To shard it along hidden dimension, set embedding_sharding_dim=1'
         'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
     )
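Note: a rough sketch of the two --embedding_sharding_dim choices described above (illustrative shapes, not the converter's implementation):

```python
import numpy as np

# Illustrative only: shard an embedding lookup table across tp_size ranks.
vocab_size, hidden_size, tp_size = 32000, 4096, 2
embedding = np.zeros((vocab_size, hidden_size), dtype=np.float16)

# embedding_sharding_dim=0: shard along the vocab dimension (one row block per rank).
vocab_shards = np.split(embedding, tp_size, axis=0)   # each shard is (16000, 4096)

# embedding_sharding_dim=1: shard along the hidden dimension (one column block per rank).
hidden_shards = np.split(embedding, tp_size, axis=1)  # each shard is (32000, 2048)
```

As the help text notes, embedding sharing is only available with the vocab-dimension layout (embedding_sharding_dim=0).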
@@ -147,15 +137,13 @@ def parse_arguments():
         '--moe_tp_size',
         type=int,
         default=-1,
-        help=
-        'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
+        help='N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
     )
     parser.add_argument(
         '--moe_ep_size',
         type=int,
         default=-1,
-        help=
-        'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
+        help='N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
     )
     args = parser.parse_args()
     return args
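Note: a small sketch of how --moe_ep_size and --moe_tp_size partition a MoE layer (illustrative numbers only, not the converter's code):

```python
# Illustrative only: with 8 experts, moe_ep_size=2 gives each expert-parallel
# group 4 experts, while moe_tp_size=2 splits every expert's weights in half.
num_experts, moe_ep_size, moe_tp_size = 8, 2, 2
experts_per_ep_rank = num_experts // moe_ep_size

for ep_rank in range(moe_ep_size):
    owned = list(range(ep_rank * experts_per_ep_rank,
                       (ep_rank + 1) * experts_per_ep_rank))
    print(f"EP group {ep_rank}: experts {owned}, "
          f"each expert's weights split across {moe_tp_size} TP ranks")
```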
@@ -249,7 +237,7 @@ def convert_and_save_hf(args):
                                                trust_remote_code=True)
         quant_config, override_fields = update_quant_config_from_hf(
             quant_config, hf_config, override_fields)
-    except:
+    except BaseException:
         logger.warning("AutoConfig cannot load the huggingface config.")
 
     if args.smoothquant is not None or args.int8_kv_cache:
@@ -339,4 +327,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    main()

+ 2 - 3
runtime/triton_trtllm/scripts/fill_template.py

@@ -1,4 +1,4 @@
-#! /usr/bin/env python3
+# /usr/bin/env python3
 from argparse import ArgumentParser
 from string import Template
 
@@ -59,8 +59,7 @@ if __name__ == "__main__":
     parser.add_argument("file_path", help="path of the .pbtxt to modify")
     parser.add_argument(
         "substitutions",
-        help=
-        "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
+        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
     )
     parser.add_argument("--in_place",
                         "-i",
                         "-i",

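Note: the substitutions argument above takes a variable_name:value list. A self-contained sketch of how such a string can drive string.Template substitution on a pbtxt-style file (assumed behaviour, not necessarily the script's exact code):

```python
from string import Template

# Illustrative only: parse "variable_name_1:value_1,variable_name_2:value_2..."
# and substitute the values into a pbtxt-style template.
pbtxt = Template('max_batch_size: ${max_batch_size}\nbackend: "${backend}"')
substitutions = "max_batch_size:16,backend:python"

mapping = dict(pair.split(":", 1) for pair in substitutions.split(","))
print(pbtxt.safe_substitute(mapping))
# max_batch_size: 16
# backend: "python"
```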
+ 1 - 2
runtime/triton_trtllm/scripts/test_llm.py

@@ -46,7 +46,6 @@ def parse_arguments(args=None):
     parser.add_argument('--top_k', type=int, default=50)
     parser.add_argument('--top_p', type=float, default=0.95)
 
-
     return parser.parse_args(args=args)
 
 
@@ -60,7 +59,7 @@ def parse_input(tokenizer,
         input_ids = tokenizer.encode(
             curr_text)
         batch_input_ids.append(input_ids)
- 
+
     batch_input_ids = [
         torch.tensor(x, dtype=torch.int32) for x in batch_input_ids
     ]