|
|
@@ -43,9 +43,9 @@ python3 client_grpc.py \
|
|
|
import argparse
|
|
|
import asyncio
|
|
|
import json
|
|
|
-import queue # Added
|
|
|
-import uuid # Added
|
|
|
-import functools # Added
|
|
|
+import queue
|
|
|
+import uuid
|
|
|
+import functools
|
|
|
|
|
|
import os
|
|
|
import time
|
|
|
@@ -55,13 +55,11 @@ from pathlib import Path
|
|
|
import numpy as np
|
|
|
import soundfile as sf
|
|
|
import tritonclient
|
|
|
-import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
|
|
|
-import tritonclient.grpc as grpcclient_sync # Added sync client import
|
|
|
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
|
|
|
+import tritonclient.grpc.aio as grpcclient_aio
|
|
|
+import tritonclient.grpc as grpcclient_sync
|
|
|
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException
|
|
|
|
|
|
-from datetime import datetime
|
|
|
|
|
|
-# --- Added UserData and callback ---
|
|
|
class UserData:
|
|
|
def __init__(self):
|
|
|
self._completed_requests = queue.Queue()
|
|
|
@@ -86,7 +84,7 @@ class UserData:
|
|
|
def callback(user_data, result, error):
|
|
|
if not error:
|
|
|
if user_data._first_chunk_time is None:
|
|
|
- user_data._first_chunk_time = time.time() # Record time of first successful chunk
|
|
|
+ user_data._first_chunk_time = time.time()
|
|
|
elif user_data._second_chunk_time is None:
|
|
|
user_data._second_chunk_time = time.time()
|
|
|
|
|
|
@@ -99,10 +97,6 @@ def callback(user_data, result, error):
|
|
|
def stream_callback(user_data_map, result, error):
|
|
|
request_id = None
|
|
|
if error:
|
|
|
- # Note: InferenceServerException doesn't have a public request_id() method in all versions.
|
|
|
- # This part might need adjustment depending on the tritonclient library version.
|
|
|
- # A more robust way would be to wrap the error with the request_id if possible.
|
|
|
- # For now, we assume we can't get request_id from error and it will timeout on the client side.
|
|
|
print(f"An error occurred in the stream callback: {error}")
|
|
|
else:
|
|
|
request_id = result.get_response().id
|
|
|
@@ -115,31 +109,9 @@ def stream_callback(user_data_map, result, error):
|
|
|
print(f"Warning: Could not find user_data for request_id {request_id}")
|
|
|
|
|
|
|
|
|
-# --- End Added UserData and callback ---
|
|
|
-
|
|
|
-
|
|
|
def write_triton_stats(stats, summary_file):
|
|
|
with open(summary_file, "w") as summary_f:
|
|
|
model_stats = stats["model_stats"]
|
|
|
- # write a note, the log is from triton_client.get_inference_statistics(), to better human readability
|
|
|
- summary_f.write(
|
|
|
- "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
|
|
|
- )
|
|
|
- summary_f.write("To learn more about the log, please refer to: \n")
|
|
|
- summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
|
|
|
- summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
|
|
|
- summary_f.write(
|
|
|
- "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
|
|
- )
|
|
|
- summary_f.write(
|
|
|
- "However, there is a trade-off between the increased queue time and the increased batch size. \n"
|
|
|
- )
|
|
|
- summary_f.write(
|
|
|
- "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
|
|
|
- )
|
|
|
- summary_f.write(
|
|
|
- "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
|
|
|
- )
|
|
|
for model_state in model_stats:
|
|
|
if "last_inference" not in model_state:
|
|
|
continue
|
|
|
@@ -150,7 +122,7 @@ def write_triton_stats(stats, summary_file):
|
|
|
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
|
|
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
|
|
summary_f.write(
|
|
|
- f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
|
|
+ f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"
|
|
|
)
|
|
|
model_batch_stats = model_state["batch_stats"]
|
|
|
for batch in model_batch_stats:
|
|
|
@@ -164,19 +136,18 @@ def write_triton_stats(stats, summary_file):
|
|
|
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
|
|
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
|
|
summary_f.write(
|
|
|
- f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
|
|
|
+ f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
|
|
|
)
|
|
|
summary_f.write(
|
|
|
- f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
|
|
|
+ f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
|
|
|
)
|
|
|
summary_f.write(
|
|
|
- f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
|
|
|
+ f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
|
|
|
)
|
|
|
|
|
|
|
|
|
def subtract_stats(stats_after, stats_before):
|
|
|
"""Subtracts two Triton inference statistics objects."""
|
|
|
- # Deep copy to avoid modifying the original stats_after
|
|
|
stats_diff = json.loads(json.dumps(stats_after))
|
|
|
|
|
|
model_stats_before_map = {
|
|
|
@@ -196,7 +167,6 @@ def subtract_stats(stats_after, stats_before):
|
|
|
if model_name in model_stats_before_map:
|
|
|
model_stat_before = model_stats_before_map[model_name]
|
|
|
|
|
|
- # Subtract counts
|
|
|
model_stat_after["inference_count"] = str(
|
|
|
int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
|
|
|
)
|
|
|
@@ -204,7 +174,6 @@ def subtract_stats(stats_after, stats_before):
|
|
|
int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
|
|
|
)
|
|
|
|
|
|
- # Subtract aggregate stats (like queue, compute times)
|
|
|
if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
|
|
|
for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
|
|
|
if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
|
|
|
@@ -217,7 +186,6 @@ def subtract_stats(stats_after, stats_before):
|
|
|
count_before = int(model_stat_before["inference_stats"][key]["count"])
|
|
|
model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
|
|
|
|
|
|
- # Subtract batch execution stats
|
|
|
if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
|
|
|
batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
|
|
|
for batch_stat_after in model_stat_after["batch_stats"]:
|
|
|
@@ -338,7 +306,6 @@ def get_args():
|
|
|
help="log directory",
|
|
|
)
|
|
|
|
|
|
- # --- Added arguments ---
|
|
|
parser.add_argument(
|
|
|
"--mode",
|
|
|
type=str,
|
|
|
@@ -379,39 +346,33 @@ def load_audio(wav_path, target_sample_rate=16000):
|
|
|
|
|
|
|
|
|
def prepare_request_input_output(
|
|
|
- protocol_client, # Can be grpcclient_aio or grpcclient_sync
|
|
|
+ protocol_client,
|
|
|
waveform,
|
|
|
reference_text,
|
|
|
target_text,
|
|
|
sample_rate=16000,
|
|
|
- padding_duration: int = None, # Optional padding for offline mode
|
|
|
+ padding_duration: int = None,
|
|
|
use_spk2info_cache: bool = False
|
|
|
):
|
|
|
"""Prepares inputs for Triton inference (offline or streaming)."""
|
|
|
assert len(waveform.shape) == 1, "waveform should be 1D"
|
|
|
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
|
|
|
|
|
- # Apply padding only if padding_duration is provided (for offline)
|
|
|
if padding_duration:
|
|
|
duration = len(waveform) / sample_rate
|
|
|
- # Estimate target duration based on text length ratio (crude estimation)
|
|
|
- # Avoid division by zero if reference_text is empty
|
|
|
if reference_text:
|
|
|
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
|
|
else:
|
|
|
- estimated_target_duration = duration # Assume target duration similar to reference if no text
|
|
|
+ estimated_target_duration = duration
|
|
|
|
|
|
- # Calculate required samples based on estimated total duration
|
|
|
required_total_samples = padding_duration * sample_rate * (
|
|
|
(int(estimated_target_duration + duration) // padding_duration) + 1
|
|
|
)
|
|
|
samples = np.zeros((1, required_total_samples), dtype=np.float32)
|
|
|
samples[0, : len(waveform)] = waveform
|
|
|
else:
|
|
|
- # No padding for streaming or if padding_duration is None
|
|
|
samples = waveform.reshape(1, -1).astype(np.float32)
|
|
|
|
|
|
- # Common input creation logic
|
|
|
inputs = [
|
|
|
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
|
|
protocol_client.InferInput(
|
|
|
@@ -450,12 +411,8 @@ def run_sync_streaming_inference(
|
|
|
):
|
|
|
"""Helper function to run the blocking sync streaming call."""
|
|
|
start_time_total = time.time()
|
|
|
- user_data.record_start_time() # Record start time for first chunk latency calculation
|
|
|
- # e.g. 08:47:34.827758
|
|
|
+ user_data.record_start_time()
|
|
|
|
|
|
- print(f"Record start time in human readable: {datetime.now()}")
|
|
|
- # input()
|
|
|
- # Send request
|
|
|
sync_triton_client.async_stream_infer(
|
|
|
model_name,
|
|
|
inputs,
|
|
|
@@ -464,30 +421,26 @@ def run_sync_streaming_inference(
|
|
|
enable_empty_final_response=True,
|
|
|
)
|
|
|
|
|
|
- # Process results
|
|
|
audios = []
|
|
|
while True:
|
|
|
try:
|
|
|
- result = user_data._completed_requests.get(timeout=20) # Add timeout
|
|
|
+ result = user_data._completed_requests.get(timeout=20)
|
|
|
if isinstance(result, InferenceServerException):
|
|
|
print(f"Received InferenceServerException: {result}")
|
|
|
- # Don't stop the stream here, just return error
|
|
|
return None, None, None, None
|
|
|
- # Get response metadata
|
|
|
response = result.get_response()
|
|
|
final = response.parameters["triton_final_response"].bool_param
|
|
|
if final is True:
|
|
|
break
|
|
|
|
|
|
audio_chunk = result.as_numpy("waveform").reshape(-1)
|
|
|
- if audio_chunk.size > 0: # Only append non-empty chunks
|
|
|
+ if audio_chunk.size > 0:
|
|
|
audios.append(audio_chunk)
|
|
|
else:
|
|
|
print("Warning: received empty audio chunk.")
|
|
|
|
|
|
except queue.Empty:
|
|
|
print(f"Timeout waiting for response for request id {request_id}")
|
|
|
- # Don't stop stream here, just return error
|
|
|
return None, None, None, None
|
|
|
|
|
|
end_time_total = time.time()
|
|
|
@@ -495,47 +448,36 @@ def run_sync_streaming_inference(
|
|
|
first_chunk_latency = user_data.get_first_chunk_latency()
|
|
|
second_chunk_latency = user_data.get_second_chunk_latency()
|
|
|
|
|
|
- # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
|
|
|
- actual_duration = 0
|
|
|
if audios:
|
|
|
- # Only spark_tts model uses cross-fade
|
|
|
if model_name == "spark_tts":
|
|
|
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
|
|
|
fade_out = np.linspace(1, 0, cross_fade_samples)
|
|
|
fade_in = np.linspace(0, 1, cross_fade_samples)
|
|
|
reconstructed_audio = None
|
|
|
|
|
|
- # Simplified reconstruction based on client_grpc_streaming.py
|
|
|
if not audios:
|
|
|
print("Warning: No audio chunks received.")
|
|
|
- reconstructed_audio = np.array([], dtype=np.float32) # Empty array
|
|
|
+ reconstructed_audio = np.array([], dtype=np.float32)
|
|
|
elif len(audios) == 1:
|
|
|
reconstructed_audio = audios[0]
|
|
|
else:
|
|
|
- reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
|
|
|
+ reconstructed_audio = audios[0][:-cross_fade_samples]
|
|
|
for i in range(1, len(audios)):
|
|
|
- # Cross-fade section
|
|
|
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
|
|
|
audios[i - 1][-cross_fade_samples:] * fade_out)
|
|
|
- # Middle section of the current chunk
|
|
|
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
|
|
|
- # Concatenate
|
|
|
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
|
|
|
- # Add the last part of the final chunk
|
|
|
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
|
|
|
|
|
|
if reconstructed_audio is not None and reconstructed_audio.size > 0:
|
|
|
actual_duration = len(reconstructed_audio) / save_sample_rate
|
|
|
- # Save reconstructed audio
|
|
|
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
|
|
else:
|
|
|
print("Warning: No audio chunks received or reconstructed.")
|
|
|
- actual_duration = 0 # Set duration to 0 if no audio
|
|
|
+ actual_duration = 0
|
|
|
else:
|
|
|
reconstructed_audio = np.concatenate(audios)
|
|
|
- print(f"reconstructed_audio: {reconstructed_audio.shape}")
|
|
|
actual_duration = len(reconstructed_audio) / save_sample_rate
|
|
|
- # Save reconstructed audio
|
|
|
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
|
|
|
|
|
else:
|
|
|
@@ -548,7 +490,7 @@ def run_sync_streaming_inference(
|
|
|
async def send_streaming(
|
|
|
manifest_item_list: list,
|
|
|
name: str,
|
|
|
- server_url: str, # Changed from sync_triton_client
|
|
|
+ server_url: str,
|
|
|
protocol_client: types.ModuleType,
|
|
|
log_interval: int,
|
|
|
model_name: str,
|
|
|
@@ -561,12 +503,12 @@ async def send_streaming(
|
|
|
total_duration = 0.0
|
|
|
latency_data = []
|
|
|
task_id = int(name[5:])
|
|
|
- sync_triton_client = None # Initialize client variable
|
|
|
+ sync_triton_client = None
|
|
|
user_data_map = {}
|
|
|
|
|
|
- try: # Wrap in try...finally to ensure client closing
|
|
|
+ try:
|
|
|
print(f"{name}: Initializing sync client for streaming...")
|
|
|
- sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
|
|
|
+ sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
|
|
|
sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
|
|
|
|
|
|
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
|
|
|
@@ -593,7 +535,6 @@ async def send_streaming(
|
|
|
user_data_map[request_id] = user_data
|
|
|
|
|
|
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
|
|
- print("target_text: ", target_text, "time: ", datetime.now())
|
|
|
total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
|
|
|
run_sync_streaming_inference,
|
|
|
sync_triton_client,
|
|
|
@@ -627,7 +568,7 @@ async def send_streaming(
|
|
|
import traceback
|
|
|
traceback.print_exc()
|
|
|
|
|
|
- finally: # Ensure client is closed
|
|
|
+ finally:
|
|
|
if sync_triton_client:
|
|
|
try:
|
|
|
print(f"{name}: Closing stream and sync client...")
|
|
|
@@ -656,7 +597,6 @@ async def send(
|
|
|
latency_data = []
|
|
|
task_id = int(name[5:])
|
|
|
|
|
|
- print(f"manifest_item_list: {manifest_item_list}")
|
|
|
for i, item in enumerate(manifest_item_list):
|
|
|
if i % log_interval == 0:
|
|
|
print(f"{name}: {i}/{len(manifest_item_list)}")
|
|
|
@@ -697,7 +637,6 @@ def load_manifests(manifest_path):
|
|
|
assert len(line.strip().split("|")) == 4
|
|
|
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
|
|
utt = Path(utt).stem
|
|
|
- # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
|
|
|
if not os.path.isabs(prompt_wav):
|
|
|
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
|
|
manifest_list.append(
|
|
|
@@ -738,23 +677,17 @@ async def main():
|
|
|
args = get_args()
|
|
|
url = f"{args.server_addr}:{args.server_port}"
|
|
|
|
|
|
- # --- Client Initialization based on mode ---
|
|
|
triton_client = None
|
|
|
protocol_client = None
|
|
|
if args.mode == "offline":
|
|
|
print("Initializing gRPC client for offline mode...")
|
|
|
- # Use the async client for offline tasks
|
|
|
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
|
|
protocol_client = grpcclient_aio
|
|
|
elif args.mode == "streaming":
|
|
|
print("Initializing gRPC client for streaming mode...")
|
|
|
- # Use the sync client for streaming tasks, handled via asyncio.to_thread
|
|
|
- # We will create one sync client instance PER TASK inside send_streaming.
|
|
|
- # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
|
|
|
- protocol_client = grpcclient_sync # protocol client for input prep
|
|
|
+ protocol_client = grpcclient_sync
|
|
|
else:
|
|
|
raise ValueError(f"Invalid mode: {args.mode}")
|
|
|
- # --- End Client Initialization ---
|
|
|
|
|
|
if args.reference_audio:
|
|
|
args.num_tasks = 1
|
|
|
@@ -776,24 +709,18 @@ async def main():
|
|
|
trust_remote_code=True,
|
|
|
)
|
|
|
manifest_item_list = []
|
|
|
- tmp_audio_path="./asset_zero_shot_prompt.wav"
|
|
|
- tmp_audio_text="希望你以后能够做的比我还好呦。"
|
|
|
for i in range(len(dataset)):
|
|
|
manifest_item_list.append(
|
|
|
{
|
|
|
"audio_filepath": dataset[i]["prompt_audio"],
|
|
|
"reference_text": dataset[i]["prompt_text"],
|
|
|
- # "audio_filepath": tmp_audio_path,
|
|
|
- # "reference_text": tmp_audio_text,
|
|
|
"target_audio_path": dataset[i]["id"],
|
|
|
"target_text": dataset[i]["target_text"],
|
|
|
}
|
|
|
)
|
|
|
- # manifest_item_list = manifest_item_list[:4]
|
|
|
else:
|
|
|
manifest_item_list = load_manifests(args.manifest_path)
|
|
|
|
|
|
- # --- Statistics Fetching (Before) ---
|
|
|
stats_client = None
|
|
|
stats_before = None
|
|
|
try:
|
|
|
@@ -803,7 +730,6 @@ async def main():
|
|
|
stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
|
|
except Exception as e:
|
|
|
print(f"Could not retrieve statistics before running tasks: {e}")
|
|
|
- # --- End Statistics Fetching (Before) ---
|
|
|
|
|
|
num_tasks = min(args.num_tasks, len(manifest_item_list))
|
|
|
manifest_item_list = split_data(manifest_item_list, num_tasks)
|
|
|
@@ -813,7 +739,6 @@ async def main():
|
|
|
tasks = []
|
|
|
start_time = time.time()
|
|
|
for i in range(num_tasks):
|
|
|
- # --- Task Creation based on mode ---
|
|
|
if args.mode == "offline":
|
|
|
task = asyncio.create_task(
|
|
|
send(
|
|
|
@@ -834,7 +759,7 @@ async def main():
|
|
|
send_streaming(
|
|
|
manifest_item_list[i],
|
|
|
name=f"task-{i}",
|
|
|
- server_url=url, # Pass URL instead of client
|
|
|
+ server_url=url,
|
|
|
protocol_client=protocol_client,
|
|
|
log_interval=args.log_interval,
|
|
|
model_name=args.model_name,
|
|
|
@@ -845,7 +770,6 @@ async def main():
|
|
|
use_spk2info_cache=args.use_spk2info_cache,
|
|
|
)
|
|
|
)
|
|
|
- # --- End Task Creation ---
|
|
|
tasks.append(task)
|
|
|
|
|
|
ans_list = await asyncio.gather(*tasks)
|
|
|
@@ -858,7 +782,7 @@ async def main():
|
|
|
for ans in ans_list:
|
|
|
if ans:
|
|
|
total_duration += ans[0]
|
|
|
- latency_data.extend(ans[1]) # Use extend for list of lists
|
|
|
+ latency_data.extend(ans[1])
|
|
|
else:
|
|
|
print("Warning: A task returned None, possibly due to an error.")
|
|
|
|
|
|
@@ -874,10 +798,8 @@ async def main():
|
|
|
s += f"({total_duration / 3600:.2f} hours)\n"
|
|
|
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
|
|
|
|
|
|
- # --- Statistics Reporting based on mode ---
|
|
|
if latency_data:
|
|
|
if args.mode == "offline":
|
|
|
- # Original offline latency calculation
|
|
|
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
|
|
if latency_list:
|
|
|
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
|
|
@@ -892,7 +814,6 @@ async def main():
|
|
|
s += "No latency data collected for offline mode.\n"
|
|
|
|
|
|
elif args.mode == "streaming":
|
|
|
- # Calculate stats for total request latency and first chunk latency
|
|
|
total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
|
|
|
first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
|
|
|
second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
|
|
|
@@ -937,7 +858,6 @@ async def main():
|
|
|
s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
|
|
|
else:
|
|
|
s += "No latency data collected.\n"
|
|
|
- # --- End Statistics Reporting ---
|
|
|
|
|
|
print(s)
|
|
|
if args.manifest_path:
|
|
|
@@ -947,12 +867,10 @@ async def main():
|
|
|
elif args.reference_audio:
|
|
|
name = Path(args.reference_audio).stem
|
|
|
else:
|
|
|
- name = "results" # Default name if no manifest/split/audio provided
|
|
|
+ name = "results"
|
|
|
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
|
|
f.write(s)
|
|
|
|
|
|
- # --- Statistics Fetching using temporary Async Client ---
|
|
|
- # Use a separate async client for fetching stats regardless of mode
|
|
|
try:
|
|
|
if stats_client and stats_before:
|
|
|
print("Fetching inference statistics after running tasks...")
|
|
|
@@ -980,11 +898,9 @@ async def main():
|
|
|
await stats_client.close()
|
|
|
except Exception as e:
|
|
|
print(f"Error closing async stats client: {e}")
|
|
|
- # --- End Statistics Fetching ---
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- # asyncio.run(main()) # Use TaskGroup for better exception handling if needed
|
|
|
async def run_main():
|
|
|
try:
|
|
|
await main()
|