|
|
@@ -59,12 +59,14 @@ import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
|
|
|
import tritonclient.grpc as grpcclient_sync # Added sync client import
|
|
|
from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
|
|
|
|
|
|
+from datetime import datetime
|
|
|
|
|
|
# --- Added UserData and callback ---
|
|
|
class UserData:
|
|
|
def __init__(self):
|
|
|
self._completed_requests = queue.Queue()
|
|
|
self._first_chunk_time = None
|
|
|
+ self._second_chunk_time = None
|
|
|
self._start_time = None
|
|
|
|
|
|
def record_start_time(self):
|
|
|
@@ -75,14 +77,44 @@ class UserData:
|
|
|
return self._first_chunk_time - self._start_time
|
|
|
return None
|
|
|
|
|
|
+ def get_second_chunk_latency(self):
|
|
|
+ if self._first_chunk_time and self._second_chunk_time:
|
|
|
+ return self._second_chunk_time - self._first_chunk_time
|
|
|
+ return None
|
|
|
+
|
|
|
|
|
|
def callback(user_data, result, error):
|
|
|
- if user_data._first_chunk_time is None and not error:
|
|
|
- user_data._first_chunk_time = time.time() # Record time of first successful chunk
|
|
|
+ if not error:
|
|
|
+ if user_data._first_chunk_time is None:
|
|
|
+ user_data._first_chunk_time = time.time() # Record time of first successful chunk
|
|
|
+ elif user_data._second_chunk_time is None:
|
|
|
+ user_data._second_chunk_time = time.time()
|
|
|
+
|
|
|
if error:
|
|
|
user_data._completed_requests.put(error)
|
|
|
else:
|
|
|
user_data._completed_requests.put(result)
|
|
|
+
|
|
|
+
|
|
|
+def stream_callback(user_data_map, result, error):
|
|
|
+ request_id = None
|
|
|
+ if error:
|
|
|
+ # Note: InferenceServerException doesn't have a public request_id() method in all versions.
|
|
|
+ # This part might need adjustment depending on the tritonclient library version.
|
|
|
+ # A more robust way would be to wrap the error with the request_id if possible.
|
|
|
+ # For now, we assume we can't get request_id from error and it will timeout on the client side.
|
|
|
+ print(f"An error occurred in the stream callback: {error}")
|
|
|
+ else:
|
|
|
+ request_id = result.get_response().id
|
|
|
+
|
|
|
+ if request_id:
|
|
|
+ user_data = user_data_map.get(request_id)
|
|
|
+ if user_data:
|
|
|
+ callback(user_data, result, error)
|
|
|
+ else:
|
|
|
+ print(f"Warning: Could not find user_data for request_id {request_id}")
|
|
|
+
|
|
|
+
|
|
|
# --- End Added UserData and callback ---
|
|
|
|
|
|
|
|
|
@@ -142,6 +174,68 @@ def write_triton_stats(stats, summary_file):
|
|
|
)
|
|
|
|
|
|
|
|
|
+def subtract_stats(stats_after, stats_before):
|
|
|
+ """Subtracts two Triton inference statistics objects."""
|
|
|
+ # Deep copy to avoid modifying the original stats_after
|
|
|
+ stats_diff = json.loads(json.dumps(stats_after))
|
|
|
+
|
|
|
+ model_stats_before_map = {
|
|
|
+ s["name"]: {
|
|
|
+ "version": s["version"],
|
|
|
+ "last_inference": s.get("last_inference", 0),
|
|
|
+ "inference_count": s.get("inference_count", 0),
|
|
|
+ "execution_count": s.get("execution_count", 0),
|
|
|
+ "inference_stats": s.get("inference_stats", {}),
|
|
|
+ "batch_stats": s.get("batch_stats", []),
|
|
|
+ }
|
|
|
+ for s in stats_before["model_stats"]
|
|
|
+ }
|
|
|
+
|
|
|
+ for model_stat_after in stats_diff["model_stats"]:
|
|
|
+ model_name = model_stat_after["name"]
|
|
|
+ if model_name in model_stats_before_map:
|
|
|
+ model_stat_before = model_stats_before_map[model_name]
|
|
|
+
|
|
|
+ # Subtract counts
|
|
|
+ model_stat_after["inference_count"] = str(
|
|
|
+ int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
|
|
|
+ )
|
|
|
+ model_stat_after["execution_count"] = str(
|
|
|
+ int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
|
|
|
+ )
|
|
|
+
|
|
|
+ # Subtract aggregate stats (like queue, compute times)
|
|
|
+ if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
|
|
|
+ for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
|
|
|
+ if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
|
|
|
+ if "ns" in model_stat_after["inference_stats"][key]:
|
|
|
+ ns_after = int(model_stat_after["inference_stats"][key]["ns"])
|
|
|
+ ns_before = int(model_stat_before["inference_stats"][key]["ns"])
|
|
|
+ model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
|
|
|
+ if "count" in model_stat_after["inference_stats"][key]:
|
|
|
+ count_after = int(model_stat_after["inference_stats"][key]["count"])
|
|
|
+ count_before = int(model_stat_before["inference_stats"][key]["count"])
|
|
|
+ model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
|
|
|
+
|
|
|
+ # Subtract batch execution stats
|
|
|
+ if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
|
|
|
+ batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
|
|
|
+ for batch_stat_after in model_stat_after["batch_stats"]:
|
|
|
+ bs = batch_stat_after["batch_size"]
|
|
|
+ if bs in batch_stats_before_map:
|
|
|
+ batch_stat_before = batch_stats_before_map[bs]
|
|
|
+ for key in ["compute_input", "compute_infer", "compute_output"]:
|
|
|
+ if key in batch_stat_after and key in batch_stat_before:
|
|
|
+ count_after = int(batch_stat_after[key]["count"])
|
|
|
+ count_before = int(batch_stat_before[key]["count"])
|
|
|
+ batch_stat_after[key]["count"] = str(count_after - count_before)
|
|
|
+
|
|
|
+ ns_after = int(batch_stat_after[key]["ns"])
|
|
|
+ ns_before = int(batch_stat_before[key]["ns"])
|
|
|
+ batch_stat_after[key]["ns"] = str(ns_after - ns_before)
|
|
|
+ return stats_diff
|
|
|
+
|
|
|
+
|
|
|
def get_args():
|
|
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
|
|
|
|
@@ -357,10 +451,10 @@ def run_sync_streaming_inference(
|
|
|
"""Helper function to run the blocking sync streaming call."""
|
|
|
start_time_total = time.time()
|
|
|
user_data.record_start_time() # Record start time for first chunk latency calculation
|
|
|
+ # e.g. 08:47:34.827758
|
|
|
|
|
|
- # Establish stream
|
|
|
- sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
|
|
|
-
|
|
|
+ print(f"Record start time in human readable: {datetime.now()}")
|
|
|
+ # input()
|
|
|
# Send request
|
|
|
sync_triton_client.async_stream_infer(
|
|
|
model_name,
|
|
|
@@ -374,11 +468,11 @@ def run_sync_streaming_inference(
|
|
|
audios = []
|
|
|
while True:
|
|
|
try:
|
|
|
- result = user_data._completed_requests.get() # Add timeout
|
|
|
+ result = user_data._completed_requests.get(timeout=20) # Add timeout
|
|
|
if isinstance(result, InferenceServerException):
|
|
|
print(f"Received InferenceServerException: {result}")
|
|
|
- sync_triton_client.stop_stream()
|
|
|
- return None, None, None # Indicate error
|
|
|
+ # Don't stop the stream here, just return error
|
|
|
+ return None, None, None, None
|
|
|
# Get response metadata
|
|
|
response = result.get_response()
|
|
|
final = response.parameters["triton_final_response"].bool_param
|
|
|
@@ -393,13 +487,13 @@ def run_sync_streaming_inference(
|
|
|
|
|
|
except queue.Empty:
|
|
|
print(f"Timeout waiting for response for request id {request_id}")
|
|
|
- sync_triton_client.stop_stream()
|
|
|
- return None, None, None # Indicate error
|
|
|
+ # Don't stop stream here, just return error
|
|
|
+ return None, None, None, None
|
|
|
|
|
|
- sync_triton_client.stop_stream()
|
|
|
end_time_total = time.time()
|
|
|
total_request_latency = end_time_total - start_time_total
|
|
|
first_chunk_latency = user_data.get_first_chunk_latency()
|
|
|
+ second_chunk_latency = user_data.get_second_chunk_latency()
|
|
|
|
|
|
# Reconstruct audio using cross-fade (from client_grpc_streaming.py)
|
|
|
actual_duration = 0
|
|
|
@@ -448,7 +542,7 @@ def run_sync_streaming_inference(
|
|
|
print("Warning: No audio chunks received.")
|
|
|
actual_duration = 0
|
|
|
|
|
|
- return total_request_latency, first_chunk_latency, actual_duration
|
|
|
+ return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
|
|
|
|
|
|
|
|
|
async def send_streaming(
|
|
|
@@ -468,10 +562,12 @@ async def send_streaming(
|
|
|
latency_data = []
|
|
|
task_id = int(name[5:])
|
|
|
sync_triton_client = None # Initialize client variable
|
|
|
+ user_data_map = {}
|
|
|
|
|
|
try: # Wrap in try...finally to ensure client closing
|
|
|
print(f"{name}: Initializing sync client for streaming...")
|
|
|
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
|
|
|
+ sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
|
|
|
|
|
|
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
|
|
|
for i, item in enumerate(manifest_item_list):
|
|
|
@@ -494,10 +590,11 @@ async def send_streaming(
|
|
|
|
|
|
request_id = str(uuid.uuid4())
|
|
|
user_data = UserData()
|
|
|
+ user_data_map[request_id] = user_data
|
|
|
|
|
|
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
|
|
-
|
|
|
- total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
|
|
|
+ print("target_text: ", target_text, "time: ", datetime.now())
|
|
|
+ total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
|
|
|
run_sync_streaming_inference,
|
|
|
sync_triton_client,
|
|
|
model_name,
|
|
|
@@ -511,12 +608,18 @@ async def send_streaming(
|
|
|
)
|
|
|
|
|
|
if total_request_latency is not None:
|
|
|
- print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
|
|
|
- latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
|
|
|
+ print(
|
|
|
+ f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
|
|
|
+ f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
|
|
|
+ f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
|
|
|
+ )
|
|
|
+ latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
|
|
|
total_duration += actual_duration
|
|
|
else:
|
|
|
print(f"{name}: Item {i} failed.")
|
|
|
|
|
|
+ del user_data_map[request_id]
|
|
|
+
|
|
|
except FileNotFoundError:
|
|
|
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
|
|
|
except Exception as e:
|
|
|
@@ -527,7 +630,8 @@ async def send_streaming(
|
|
|
finally: # Ensure client is closed
|
|
|
if sync_triton_client:
|
|
|
try:
|
|
|
- print(f"{name}: Closing sync client...")
|
|
|
+ print(f"{name}: Closing stream and sync client...")
|
|
|
+ sync_triton_client.stop_stream()
|
|
|
sync_triton_client.close()
|
|
|
except Exception as e:
|
|
|
print(f"{name}: Error closing sync client: {e}")
|
|
|
@@ -685,9 +789,22 @@ async def main():
|
|
|
"target_text": dataset[i]["target_text"],
|
|
|
}
|
|
|
)
|
|
|
+ # manifest_item_list = manifest_item_list[:4]
|
|
|
else:
|
|
|
manifest_item_list = load_manifests(args.manifest_path)
|
|
|
|
|
|
+ # --- Statistics Fetching (Before) ---
|
|
|
+ stats_client = None
|
|
|
+ stats_before = None
|
|
|
+ try:
|
|
|
+ print("Initializing temporary async client for fetching stats...")
|
|
|
+ stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
|
|
+ print("Fetching inference statistics before running tasks...")
|
|
|
+ stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Could not retrieve statistics before running tasks: {e}")
|
|
|
+ # --- End Statistics Fetching (Before) ---
|
|
|
+
|
|
|
num_tasks = min(args.num_tasks, len(manifest_item_list))
|
|
|
manifest_item_list = split_data(manifest_item_list, num_tasks)
|
|
|
|
|
|
@@ -776,8 +893,9 @@ async def main():
|
|
|
|
|
|
elif args.mode == "streaming":
|
|
|
# Calculate stats for total request latency and first chunk latency
|
|
|
- total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
|
|
|
- first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
|
|
|
+ total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
|
|
|
+ first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
|
|
|
+ second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
|
|
|
|
|
|
s += "\n--- Total Request Latency ---\n"
|
|
|
if total_latency_list:
|
|
|
@@ -804,6 +922,19 @@ async def main():
|
|
|
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
|
|
|
else:
|
|
|
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
|
|
|
+
|
|
|
+ s += "\n--- Second Chunk Latency ---\n"
|
|
|
+ if second_chunk_latency_list:
|
|
|
+ avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
|
|
|
+ variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
|
|
|
+ s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
|
|
|
+ s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
|
|
|
+ s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
|
|
|
+ s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
|
|
|
+ s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
|
|
|
+ s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
|
|
|
+ else:
|
|
|
+ s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
|
|
|
else:
|
|
|
s += "No latency data collected.\n"
|
|
|
# --- End Statistics Reporting ---
|
|
|
@@ -822,20 +953,23 @@ async def main():
|
|
|
|
|
|
# --- Statistics Fetching using temporary Async Client ---
|
|
|
# Use a separate async client for fetching stats regardless of mode
|
|
|
- stats_client = None
|
|
|
try:
|
|
|
- print("Initializing temporary async client for fetching stats...")
|
|
|
- stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
|
|
- print("Fetching inference statistics...")
|
|
|
- # Fetching for all models, filtering might be needed depending on server setup
|
|
|
- stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
|
|
- print("Fetching model config...")
|
|
|
- metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
|
|
|
+ if stats_client and stats_before:
|
|
|
+ print("Fetching inference statistics after running tasks...")
|
|
|
+ stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
|
|
+
|
|
|
+ print("Calculating statistics difference...")
|
|
|
+ stats = subtract_stats(stats_after, stats_before)
|
|
|
|
|
|
- write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
|
|
+ print("Fetching model config...")
|
|
|
+ metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
|
|
|
|
|
|
- with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
|
|
- json.dump(metadata, f, indent=4)
|
|
|
+ write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
|
|
+
|
|
|
+ with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
|
|
+ json.dump(metadata, f, indent=4)
|
|
|
+ else:
|
|
|
+ print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f"Could not retrieve statistics or config: {e}")
|