@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
 #           2023 Nvidia (authors: Yuekai Zhang)
 #           2023 Recurrent.ai (authors: Songtao Shi)
@@ -46,7 +45,7 @@ import asyncio
 import json
 import queue # Added
 import uuid # Added
-import functools # Added
+import functools  # Added
 
 import os
 import time
@@ -56,9 +55,9 @@ from pathlib import Path
 import numpy as np
 import soundfile as sf
 import tritonclient
-import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
-import tritonclient.grpc as grpcclient_sync # Added sync client import
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
+import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
+import tritonclient.grpc as grpcclient_sync  # Added sync client import
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException
 
 
 # --- Added UserData and callback ---
@@ -76,9 +75,10 @@ class UserData:
             return self._first_chunk_time - self._start_time
         return None
 
+
 def callback(user_data, result, error):
     if user_data._first_chunk_time is None and not error:
-        user_data._first_chunk_time = time.time() # Record time of first successful chunk
+        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
     if error:
         user_data._completed_requests.put(error)
     else:
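
This is the standard Triton streaming pattern: a `UserData` object carries a thread-safe queue plus two timestamps, and `callback` runs on the client's gRPC thread. A minimal sketch of how the pieces fit together — the `UserData.__init__` body is inferred from the attributes referenced in this diff, not shown in it:

```python
import functools
import queue
import time

import tritonclient.grpc as grpcclient_sync


class UserData:
    def __init__(self):
        # Queue is thread-safe: the gRPC callback thread produces,
        # the caller's loop consumes.
        self._completed_requests = queue.Queue()
        self._start_time = None
        self._first_chunk_time = None

    def record_start_time(self):
        self._start_time = time.time()


def callback(user_data, result, error):
    # The first successful chunk marks first-chunk latency.
    if user_data._first_chunk_time is None and not error:
        user_data._first_chunk_time = time.time()
    user_data._completed_requests.put(error if error else result)


user_data = UserData()
client = grpcclient_sync.InferenceServerClient(url="localhost:8001")
client.start_stream(callback=functools.partial(callback, user_data))
# client.async_stream_infer(...) would follow; results then arrive on the queue.
```
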
@@ -206,8 +206,11 @@ def get_args():
         "--model-name",
         type=str,
         default="f5_tts",
-        choices=["f5_tts", "spark_tts", "cosyvoice2"],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+        choices=[
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )
 
     parser.add_argument(
@@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000):
         waveform = resample(waveform, num_samples)
     return waveform, target_sample_rate
 
+
 def prepare_request_input_output(
-    protocol_client, # Can be grpcclient_aio or grpcclient_sync
+    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None # Optional padding for offline mode
+    padding_duration: int = None  # Optional padding for offline mode
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -291,9 +295,9 @@ def prepare_request_input_output(
         # Estimate target duration based on text length ratio (crude estimation)
         # Avoid division by zero if reference_text is empty
         if reference_text:
-            estimated_target_duration = duration / len(reference_text) * len(target_text)
+            estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
-            estimated_target_duration = duration # Assume target duration similar to reference if no text
+            estimated_target_duration = duration  # Assume target duration similar to reference if no text
 
         # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
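
The ratio-based estimate is easy to sanity-check with concrete numbers; the values below are invented for illustration:

```python
# A 4.0 s reference clip with 20 reference characters and a 50-character
# target implies roughly 4.0 / 20 * 50 = 10.0 s of synthesized audio.
duration = 4.0
reference_text = "a" * 20
target_text = "b" * 50
estimated_target_duration = duration / len(reference_text) * len(target_text)
assert abs(estimated_target_duration - 10.0) < 1e-9
```
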
@@ -329,6 +333,7 @@ def prepare_request_input_output(
 
     return inputs, outputs
 
+
 def run_sync_streaming_inference(
     sync_triton_client: tritonclient.grpc.InferenceServerClient,
     model_name: str,
@@ -342,7 +347,7 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
-    user_data.record_start_time() # Record start time for first chunk latency calculation
+    user_data.record_start_time()  # Record start time for first chunk latency calculation
 
     # Establish stream
     sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
@@ -360,11 +365,11 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get() # Add timeout
+            result = user_data._completed_requests.get(timeout=300)  # Timeout so the queue.Empty handler below is reachable; 300 s is an arbitrary guard
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
                 sync_triton_client.stop_stream()
-                return None, None, None # Indicate error
+                return None, None, None  # Indicate error
             # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
@@ -372,15 +377,15 @@ def run_sync_streaming_inference(
                 break
 
             audio_chunk = result.as_numpy("waveform").reshape(-1)
-            if audio_chunk.size > 0: # Only append non-empty chunks
-                audios.append(audio_chunk)
+            if audio_chunk.size > 0:  # Only append non-empty chunks
+                audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")
 
         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
             sync_triton_client.stop_stream()
-            return None, None, None # Indicate error
+            return None, None, None  # Indicate error
 
     sync_triton_client.stop_stream()
     end_time_total = time.time()
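
Note on the timeout added above: `queue.Queue.get()` raises `queue.Empty` only when called with a timeout (or `block=False`); a bare `get()` blocks forever and the `except queue.Empty` branch would be dead code. A minimal demonstration:

```python
import queue

q = queue.Queue()
try:
    q.get(timeout=0.1)  # raises queue.Empty after 0.1 s on an empty queue
except queue.Empty:
    print("timed out as expected")
# q.get() with no arguments would block indefinitely here.
```
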
@@ -398,19 +403,19 @@ def run_sync_streaming_inference(
     # Simplified reconstruction based on client_grpc_streaming.py
     if not audios:
         print("Warning: No audio chunks received.")
-        reconstructed_audio = np.array([], dtype=np.float32) # Empty array
+        reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
     elif len(audios) == 1:
         reconstructed_audio = audios[0]
     else:
-        reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
+        reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
         for i in range(1, len(audios)):
-            # Cross-fade section
-            cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
-                                   audios[i - 1][-cross_fade_samples:] * fade_out)
-            # Middle section of the current chunk
-            middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-            # Concatenate
-            reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
+            # Cross-fade section
+            cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
+                                   audios[i - 1][-cross_fade_samples:] * fade_out)
+            # Middle section of the current chunk
+            middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
+            # Concatenate
+            reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
         # Add the last part of the final chunk
         reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
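
The loop above overlap-adds consecutive chunks. `fade_in`/`fade_out` are defined elsewhere in the file; the sketch below assumes the usual linear ramps, with invented chunk sizes, and shows why the seam is seamless:

```python
import numpy as np

sample_rate = 16000
cross_fade_duration = 0.1  # seconds of overlap between chunks (illustrative)
cross_fade_samples = int(cross_fade_duration * sample_rate)
fade_in = np.linspace(0.0, 1.0, cross_fade_samples)
fade_out = 1.0 - fade_in  # complementary ramp: fade_in + fade_out == 1

# Two fake constant-amplitude chunks that overlap by cross_fade_samples.
prev_chunk = np.ones(3 * cross_fade_samples, dtype=np.float32)
next_chunk = np.ones(3 * cross_fade_samples, dtype=np.float32)

# Blend the tail of the previous chunk with the head of the next one.
overlap = (next_chunk[:cross_fade_samples] * fade_in +
           prev_chunk[-cross_fade_samples:] * fade_out)
out = np.concatenate([prev_chunk[:-cross_fade_samples], overlap,
                      next_chunk[cross_fade_samples:]])

# Constant input stays constant through the crossfade region.
assert np.allclose(overlap, 1.0)
```
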
@@ -421,11 +426,11 @@ def run_sync_streaming_inference(
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
         else:
             print("Warning: No audio chunks received or reconstructed.")
-            actual_duration = 0 # Set duration to 0 if no audio
+            actual_duration = 0  # Set duration to 0 if no audio
 
     else:
-        print("Warning: No audio chunks received.")
-        actual_duration = 0
+        print("Warning: No audio chunks received.")
+        actual_duration = 0
 
     return total_request_latency, first_chunk_latency, actual_duration
@@ -433,7 +438,7 @@ def run_sync_streaming_inference(
 async def send_streaming(
     manifest_item_list: list,
     name: str,
-    server_url: str, # Changed from sync_triton_client
+    server_url: str,  # Changed from sync_triton_client
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -445,11 +450,11 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
-    sync_triton_client = None # Initialize client variable
+    sync_triton_client = None  # Initialize client variable
 
-    try: # Wrap in try...finally to ensure client closing
+    try:  # Wrap in try...finally to ensure client closing
         print(f"{name}: Initializing sync client for streaming...")
-        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here
 
         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -491,8 +496,7 @@ async def send_streaming(
                     latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                     total_duration += actual_duration
                 else:
-                    print(f"{name}: Item {i} failed.")
-
+                    print(f"{name}: Item {i} failed.")
 
             except FileNotFoundError:
                 print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
@@ -501,8 +505,7 @@ async def send_streaming(
                 import traceback
                 traceback.print_exc()
 
-
-    finally: # Ensure client is closed
+    finally:  # Ensure client is closed
         if sync_triton_client:
             try:
                 print(f"{name}: Closing sync client...")
@@ -510,10 +513,10 @@ async def send_streaming(
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")
 
-
     print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
     return total_duration, latency_data
 
+
 async def send(
     manifest_item_list: list,
     name: str,
@@ -605,6 +608,7 @@ def split_data(data, k):
 
     return result
 
+
 async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
url = f"{args.server_addr}:{args.server_port}"
|
|
|
@@ -622,7 +626,7 @@ async def main():
|
|
|
# Use the sync client for streaming tasks, handled via asyncio.to_thread
|
|
|
# We will create one sync client instance PER TASK inside send_streaming.
|
|
|
# triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
|
|
|
- protocol_client = grpcclient_sync # protocol client for input prep
|
|
|
+ protocol_client = grpcclient_sync # protocol client for input prep
|
|
|
else:
|
|
|
raise ValueError(f"Invalid mode: {args.mode}")
|
|
|
# --- End Client Initialization ---
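
The comments above describe the design: each streaming task owns its own sync gRPC client and pushes the blocking work onto a thread with `asyncio.to_thread`, so tasks run concurrently without blocking the event loop. A minimal sketch of that hand-off (function and argument names here are illustrative, not the file's exact signatures):

```python
import asyncio

import tritonclient.grpc as grpcclient_sync


def blocking_stream(server_url: str) -> float:
    # Stand-in for run_sync_streaming_inference: the sync client is created
    # inside the worker so each task owns its own connection.
    client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
    try:
        ...  # start_stream / async_stream_infer / drain the result queue
        return 0.0  # would be the synthesized duration
    finally:
        client.close()


async def send_streaming_task(server_url: str) -> float:
    # to_thread keeps the event loop responsive while gRPC blocks.
    return await asyncio.to_thread(blocking_stream, server_url)
```
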
@@ -682,11 +686,11 @@ async def main():
                 )
             )
         elif args.mode == "streaming":
-            task = asyncio.create_task(
+            task = asyncio.create_task(
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
-                    server_url=url, # Pass URL instead of client
+                    server_url=url,  # Pass URL instead of client
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -709,16 +713,15 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
-            latency_data.extend(ans[1]) # Use extend for list of lists
+            latency_data.extend(ans[1])  # Use extend for list of lists
         else:
-            print("Warning: A task returned None, possibly due to an error.")
-
+            print("Warning: A task returned None, possibly due to an error.")
 
     if total_duration == 0:
         print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
         rtf = float('inf')
     else:
-        rtf = elapsed / total_duration
+        rtf = elapsed / total_duration
 
     s = f"Mode: {args.mode}\n"
     s += f"RTF: {rtf:.4f}\n"
@@ -759,7 +762,7 @@ async def main():
         s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
         s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
     else:
-        s += "No total request latency data collected.\n"
+        s += "No total request latency data collected.\n"
 
     s += "\n--- First Chunk Latency ---\n"
     if first_chunk_latency_list:
@@ -772,7 +775,7 @@ async def main():
             s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
             s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
         else:
-            s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
+            s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
     else:
         s += "No latency data collected.\n"
     # --- End Statistics Reporting ---
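
For readers new to these metrics: RTF is wall-clock time divided by synthesized audio duration (RTF < 1 means faster than real time), and the percentile lines come straight from `np.percentile`. A toy calculation with invented latencies:

```python
import numpy as np

elapsed = 12.5         # wall-clock seconds for the whole benchmark (made up)
total_duration = 50.0  # seconds of audio synthesized (made up)
rtf = elapsed / total_duration  # 0.25, i.e. 4x faster than real time

first_chunk_latency_list = [0.18, 0.22, 0.25, 0.31, 0.90]  # seconds, made up
p50 = np.percentile(first_chunk_latency_list, 50) * 1000.0  # median, ms
p99 = np.percentile(first_chunk_latency_list, 99) * 1000.0  # tail latency, ms
print(f"RTF: {rtf:.4f}, P50: {p50:.2f} ms, P99: {p99:.2f} ms")
```
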
@@ -785,7 +788,7 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
-        name = "results" # Default name if no manifest/split/audio provided
+        name = "results"  # Default name if no manifest/split/audio provided
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)