client_grpc.py

# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#           2023 Nvidia (authors: Yuekai Zhang)
#           2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  17. """
  18. This script supports to load dataset from huggingface and sends it to the server
  19. for decoding, in parallel.
  20. Usage:
  21. num_task=2
  22. # For offline F5-TTS
  23. python3 client_grpc.py \
  24. --server-addr localhost \
  25. --model-name f5_tts \
  26. --num-tasks $num_task \
  27. --huggingface-dataset yuekai/seed_tts \
  28. --split-name test_zh \
  29. --log-dir ./log_concurrent_tasks_${num_task}
  30. # For offline Spark-TTS-0.5B
  31. python3 client_grpc.py \
  32. --server-addr localhost \
  33. --model-name spark_tts \
  34. --num-tasks $num_task \
  35. --huggingface-dataset yuekai/seed_tts \
  36. --split-name wenetspeech4tts \
  37. --log-dir ./log_concurrent_tasks_${num_task}
  38. """

import argparse
import asyncio
import functools
import json
import os
import queue
import time
import types
import uuid
from pathlib import Path

import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc as grpcclient_sync  # sync client, used for streaming requests
import tritonclient.grpc.aio as grpcclient_aio  # async client, used for offline requests
from tritonclient.utils import InferenceServerException, np_to_triton_dtype


# --- UserData and callback used by the streaming (sync) client ---
class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()
        self._first_chunk_time = None
        self._start_time = None

    def record_start_time(self):
        self._start_time = time.time()

    def get_first_chunk_latency(self):
        if self._first_chunk_time and self._start_time:
            return self._first_chunk_time - self._start_time
        return None


def callback(user_data, result, error):
    if user_data._first_chunk_time is None and not error:
        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)


def write_triton_stats(stats, summary_file):
    with open(summary_file, "w") as summary_f:
        model_stats = stats["model_stats"]
        # Write a note that the log comes from triton_client.get_inference_statistics(),
        # rewritten here for better human readability.
        summary_f.write(
            "This log is parsed from triton_client.get_inference_statistics(), for better human readability. \n"
        )
        summary_f.write("To learn more about the log, please refer to: \n")
        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
        summary_f.write(
            "To improve throughput, we would like to let requests wait in the queue for a while and then execute them with a larger batch size. \n"
        )
        summary_f.write(
            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
        )
        summary_f.write(
            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to tune this. \n"
        )
        summary_f.write(
            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
        )
        for model_state in model_stats:
            if "last_inference" not in model_state:
                continue
            summary_f.write(f"model name is {model_state['name']} \n")
            model_inference_stats = model_state["inference_stats"]
            total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
            total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
            total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
            total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
            summary_f.write(
                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, "
                f"compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
            )
            model_batch_stats = model_state["batch_stats"]
            for batch in model_batch_stats:
                batch_size = int(batch["batch_size"])
                compute_input = batch["compute_input"]
                compute_output = batch["compute_output"]
                compute_infer = batch["compute_infer"]
                batch_count = int(compute_infer["count"])
                assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
                compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
                compute_input_time_ms = int(compute_input["ns"]) / 1e6
                compute_output_time_ms = int(compute_output["ns"]) / 1e6
                summary_f.write(
                    f"execute inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
                    f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
                    f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, "
                    f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
                )
                summary_f.write(
                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
                )
                summary_f.write(
                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
                )
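

# The statistics JSON consumed above is assumed to have roughly the following
# shape (abridged sketch with made-up values; see the Triton metrics docs linked
# in the summary for the authoritative schema):
# {
#   "model_stats": [
#     {
#       "name": "spark_tts",
#       "last_inference": "...",
#       "inference_stats": {
#         "queue": {"ns": "123"}, "compute_infer": {"ns": "456"},
#         "compute_input": {"ns": "..."}, "compute_output": {"ns": "..."}
#       },
#       "batch_stats": [
#         {"batch_size": "1",
#          "compute_infer": {"count": "10", "ns": "..."},
#          "compute_input": {"count": "10", "ns": "..."},
#          "compute_output": {"count": "10", "ns": "..."}}
#       ]
#     }
#   ]
# }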


def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=8001,
        help="gRPC port of the Triton server",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Path to a single reference audio file. It can't be specified at the same time as --manifest-path",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="",
        help="Transcript of the reference audio (used with --reference-audio)",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="",
        help="Text to synthesize (used with --reference-audio)",
    )

    parser.add_argument(
        "--huggingface-dataset",
        type=str,
        default="yuekai/seed_tts",
        help="Dataset name on the HuggingFace dataset hub",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="Dataset split name",
    )

    parser.add_argument(
        "--manifest-path",
        type=str,
        default=None,
        help="Path to a manifest file whose lines are 'utt|prompt_text|prompt_wav|gt_text'.",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        choices=["f5_tts", "spark_tts", "cosyvoice2"],
        help="Triton model_repo module name to request",
    )

    parser.add_argument(
        "--num-tasks",
        type=int,
        default=1,
        help="Number of concurrent tasks for sending requests",
    )

    parser.add_argument(
        "--log-interval",
        type=int,
        default=5,
        help="Controls how frequently we print the log.",
    )

    parser.add_argument(
        "--compute-wer",
        action="store_true",
        default=False,
        help="True to compute WER.",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        required=False,
        default="./tmp",
        help="Log directory",
    )

    parser.add_argument(
        "--mode",
        type=str,
        default="offline",
        choices=["offline", "streaming"],
        help="Select offline or streaming benchmark mode.",
    )

    parser.add_argument(
        "--chunk-overlap-duration",
        type=float,
        default=0.1,
        help="Chunk overlap duration for streaming reconstruction (in seconds).",
    )

    return parser.parse_args()


def load_audio(wav_path, target_sample_rate=16000):
    """Load audio from a file path or a HuggingFace-style dict with 'array'/'sampling_rate'."""
    assert target_sample_rate == 16000, "hard coded in server"
    if isinstance(wav_path, dict):
        waveform = wav_path["array"]
        sample_rate = wav_path["sampling_rate"]
    else:
        waveform, sample_rate = sf.read(wav_path)
    if sample_rate != target_sample_rate:
        from scipy.signal import resample

        num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
        waveform = resample(waveform, num_samples)
    return waveform, target_sample_rate


def prepare_request_input_output(
    protocol_client,  # Either grpcclient_aio or grpcclient_sync
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None,  # Optional padding for offline mode
):
    """Prepares inputs for Triton inference (offline or streaming)."""
    assert len(waveform.shape) == 1, "waveform should be 1D"
    lengths = np.array([[len(waveform)]], dtype=np.int32)

    # Apply padding only if padding_duration is provided (offline mode)
    if padding_duration:
        duration = len(waveform) / sample_rate
        # Estimate the target duration from the text-length ratio (crude estimate);
        # avoid division by zero if reference_text is empty.
        if reference_text:
            estimated_target_duration = duration / len(reference_text) * len(target_text)
        else:
            estimated_target_duration = duration  # Assume target duration similar to reference if no text

        # Zero-pad up to the next multiple of padding_duration (in seconds) that
        # covers the reference plus the estimated target duration.
        required_total_samples = padding_duration * sample_rate * (
            (int(estimated_target_duration + duration) // padding_duration) + 1
        )
        samples = np.zeros((1, required_total_samples), dtype=np.float32)
        samples[0, : len(waveform)] = waveform
    else:
        # No padding for streaming, or if padding_duration is None
        samples = waveform.reshape(1, -1).astype(np.float32)

    # Common input creation logic
    inputs = [
        protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
        protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
        protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
        protocol_client.InferInput("target_text", [1, 1], "BYTES"),
    ]
    inputs[0].set_data_from_numpy(samples)
    inputs[1].set_data_from_numpy(lengths)

    input_data_numpy = np.array([reference_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[2].set_data_from_numpy(input_data_numpy)

    input_data_numpy = np.array([target_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[3].set_data_from_numpy(input_data_numpy)

    outputs = [protocol_client.InferRequestedOutput("waveform")]
    return inputs, outputs
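

# Worked example of the padding arithmetic above (illustrative numbers only):
# with padding_duration=10, sample_rate=16000, a 3 s reference and an estimated
# 4 s target, required_total_samples = 10 * 16000 * ((int(4 + 3) // 10) + 1)
# = 160000, i.e. the reference waveform is zero-padded up to the next 10 s boundary.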


def run_sync_streaming_inference(
    sync_triton_client: tritonclient.grpc.InferenceServerClient,
    model_name: str,
    inputs: list,
    outputs: list,
    request_id: str,
    user_data: UserData,
    chunk_overlap_duration: float,
    save_sample_rate: int,
    audio_save_path: str,
):
    """Helper function to run the blocking sync streaming call."""
    start_time_total = time.time()
    user_data.record_start_time()  # Record start time for first chunk latency calculation

    # Establish stream
    sync_triton_client.start_stream(callback=functools.partial(callback, user_data))

    # Send request
    sync_triton_client.async_stream_infer(
        model_name,
        inputs,
        request_id=request_id,
        outputs=outputs,
        enable_empty_final_response=True,
    )

    # Process results
    audios = []
    while True:
        try:
            # NOTE: this get() blocks indefinitely; pass a timeout here if the
            # queue.Empty handling below should ever trigger.
            result = user_data._completed_requests.get()
            if isinstance(result, InferenceServerException):
                print(f"Received InferenceServerException: {result}")
                sync_triton_client.stop_stream()
                return None, None, None  # Indicate error
            # Get response metadata
            response = result.get_response()
            final = response.parameters["triton_final_response"].bool_param
            if final is True:
                break

            audio_chunk = result.as_numpy("waveform").reshape(-1)
            if audio_chunk.size > 0:  # Only append non-empty chunks
                audios.append(audio_chunk)
            else:
                print("Warning: received empty audio chunk.")
        except queue.Empty:
            print(f"Timeout waiting for response for request id {request_id}")
            sync_triton_client.stop_stream()
            return None, None, None  # Indicate error

    sync_triton_client.stop_stream()
    end_time_total = time.time()
    total_request_latency = end_time_total - start_time_total
    first_chunk_latency = user_data.get_first_chunk_latency()

    # Reconstruct audio using cross-fade (adapted from client_grpc_streaming.py)
    actual_duration = 0
    if audios:
        cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
        fade_out = np.linspace(1, 0, cross_fade_samples)
        fade_in = np.linspace(0, 1, cross_fade_samples)

        if len(audios) == 1:
            reconstructed_audio = audios[0]
        else:
            # Start with the first chunk minus its trailing overlap
            reconstructed_audio = audios[0][:-cross_fade_samples]
            for i in range(1, len(audios)):
                # Cross-fade the overlapping region between consecutive chunks
                cross_faded_overlap = (
                    audios[i][:cross_fade_samples] * fade_in
                    + audios[i - 1][-cross_fade_samples:] * fade_out
                )
                # Middle section of the current chunk
                middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
                # Concatenate
                reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
            # Add the last part of the final chunk
            reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])

        if reconstructed_audio is not None and reconstructed_audio.size > 0:
            actual_duration = len(reconstructed_audio) / save_sample_rate
            # Save reconstructed audio
            os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
            sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
        else:
            print("Warning: No audio chunks received or reconstructed.")
            actual_duration = 0  # Set duration to 0 if no audio
    else:
        print("Warning: No audio chunks received.")
        actual_duration = 0

    return total_request_latency, first_chunk_latency, actual_duration


async def send_streaming(
    manifest_item_list: list,
    name: str,
    server_url: str,  # Changed from sync_triton_client: each task creates its own sync client
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    chunk_overlap_duration: float = 0.1,
    padding_duration: int = None,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])

    sync_triton_client = None  # Initialize client variable
    try:  # Wrap in try...finally to ensure the client is closed
        print(f"{name}: Initializing sync client for streaming...")
        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here

        print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
        for i, item in enumerate(manifest_item_list):
            if i % log_interval == 0:
                print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
            try:
                waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
                reference_text, target_text = item["reference_text"], item["target_text"]

                inputs, outputs = prepare_request_input_output(
                    protocol_client,
                    waveform,
                    reference_text,
                    target_text,
                    sample_rate,
                    padding_duration=padding_duration,
                )
                request_id = str(uuid.uuid4())
                user_data = UserData()
                audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")

                total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
                    run_sync_streaming_inference,
                    sync_triton_client,
                    model_name,
                    inputs,
                    outputs,
                    request_id,
                    user_data,
                    chunk_overlap_duration,
                    save_sample_rate,
                    audio_save_path,
                )

                if total_request_latency is not None:
                    print(
                        f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
                        f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
                    )
                    latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                    total_duration += actual_duration
                else:
                    print(f"{name}: Item {i} failed.")

            except FileNotFoundError:
                print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
            except Exception as e:
                print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
                import traceback

                traceback.print_exc()

    finally:  # Ensure the client is closed
        if sync_triton_client:
            try:
                print(f"{name}: Closing sync client...")
                sync_triton_client.close()
            except Exception as e:
                print(f"{name}: Error closing sync client: {e}")

    print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
    return total_duration, latency_data


async def send(
    manifest_item_list: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    padding_duration: int = None,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])

    print(f"manifest_item_list: {manifest_item_list}")
    for i, item in enumerate(manifest_item_list):
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(manifest_item_list)}")

        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
        reference_text, target_text = item["reference_text"], item["target_text"]

        inputs, outputs = prepare_request_input_output(
            protocol_client,
            waveform,
            reference_text,
            target_text,
            sample_rate,
            padding_duration=padding_duration,
        )
        sequence_id = 100000000 + i + task_id * 10

        start = time.time()
        response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)

        audio = response.as_numpy("waveform").reshape(-1)
        actual_duration = len(audio) / save_sample_rate

        end = time.time() - start

        audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")

        latency_data.append((end, actual_duration))
        total_duration += actual_duration

    return total_duration, latency_data


def load_manifests(manifest_path):
    with open(manifest_path, "r") as f:
        manifest_list = []
        for line in f:
            assert len(line.strip().split("|")) == 4
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
            utt = Path(utt).stem
            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
            manifest_list.append(
                {
                    "audio_filepath": prompt_wav,
                    "reference_text": prompt_text,
                    "target_text": gt_text,
                    "target_audio_path": utt,
                }
            )
    return manifest_list
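

# Each manifest line is expected to contain four "|"-separated fields, e.g.
# (hypothetical example; a relative prompt_wav path is resolved against the
# directory of the manifest file):
#   utt_0001|reference transcript|wavs/utt_0001_prompt.wav|target transcript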


def split_data(data, k):
    n = len(data)
    if n < k:
        print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
        k = n

    quotient = n // k
    remainder = n % k

    result = []
    start = 0
    for i in range(k):
        if i < remainder:
            end = start + quotient + 1
        else:
            end = start + quotient

        result.append(data[start:end])
        start = end

    return result
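

# Example of the splitting behaviour above (illustrative):
#   split_data([0, 1, 2, 3, 4], 2) -> [[0, 1, 2], [3, 4]]
# 5 // 2 == 2 with remainder 1, so the first task receives one extra item.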


async def main():
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"

    # --- Client initialization based on mode ---
    triton_client = None
    protocol_client = None
    if args.mode == "offline":
        print("Initializing gRPC client for offline mode...")
        # Use the async client for offline tasks
        triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        protocol_client = grpcclient_aio
    elif args.mode == "streaming":
        print("Initializing gRPC client for streaming mode...")
        # Streaming uses the sync client, driven via asyncio.to_thread.
        # One sync client instance is created PER TASK inside send_streaming.
        protocol_client = grpcclient_sync  # protocol client for input preparation
    else:
        raise ValueError(f"Invalid mode: {args.mode}")
    # --- End client initialization ---

    if args.reference_audio:
        args.num_tasks = 1
        args.log_interval = 1
        manifest_item_list = [
            {
                "reference_text": args.reference_text,
                "target_text": args.target_text,
                "audio_filepath": args.reference_audio,
                "target_audio_path": "test",
            }
        ]
    elif args.huggingface_dataset:
        import datasets

        dataset = datasets.load_dataset(
            args.huggingface_dataset,
            split=args.split_name,
            trust_remote_code=True,
        )
        manifest_item_list = []
        for i in range(len(dataset)):
            manifest_item_list.append(
                {
                    "audio_filepath": dataset[i]["prompt_audio"],
                    "reference_text": dataset[i]["prompt_text"],
                    "target_audio_path": dataset[i]["id"],
                    "target_text": dataset[i]["target_text"],
                }
            )
    else:
        manifest_item_list = load_manifests(args.manifest_path)

    num_tasks = min(args.num_tasks, len(manifest_item_list))
    manifest_item_list = split_data(manifest_item_list, num_tasks)

    os.makedirs(args.log_dir, exist_ok=True)
    tasks = []
    start_time = time.time()
    for i in range(num_tasks):
        # --- Task creation based on mode ---
        if args.mode == "offline":
            task = asyncio.create_task(
                send(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    triton_client=triton_client,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=1,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                )
            )
        elif args.mode == "streaming":
            task = asyncio.create_task(
                send_streaming(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    server_url=url,  # Pass the URL; each task creates its own sync client
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=10,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    chunk_overlap_duration=args.chunk_overlap_duration,
                )
            )
        # --- End task creation ---
        tasks.append(task)

    ans_list = await asyncio.gather(*tasks)

    end_time = time.time()
    elapsed = end_time - start_time

    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        if ans:
            total_duration += ans[0]
            latency_data.extend(ans[1])  # Use extend to merge the per-task lists
        else:
            print("Warning: A task returned None, possibly due to an error.")

    if total_duration == 0:
        print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
        rtf = float("inf")
    else:
        rtf = elapsed / total_duration

    s = f"Mode: {args.mode}\n"
    s += f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration / 3600:.2f} hours)\n"
    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

    # --- Statistics reporting based on mode ---
    if latency_data:
        if args.mode == "offline":
            # Original offline latency calculation
            latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
            if latency_list:
                latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
                latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
                s += f"latency_variance: {latency_variance:.2f}\n"
                s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
                s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
                s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
                s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_latency_ms: {latency_ms:.2f}\n"
            else:
                s += "No latency data collected for offline mode.\n"
        elif args.mode == "streaming":
            # Calculate stats for total request latency and first chunk latency
            total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
            first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]

            s += "\n--- Total Request Latency ---\n"
            if total_latency_list:
                avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
                variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
                s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
                s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
                s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
                s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
                s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
            else:
                s += "No total request latency data collected.\n"

            s += "\n--- First Chunk Latency ---\n"
            if first_chunk_latency_list:
                avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
                variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
                s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
                s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
            else:
                s += "No first chunk latency data collected (check for errors or if all requests failed before the first chunk).\n"
    else:
        s += "No latency data collected.\n"
    # --- End statistics reporting ---

    print(s)

    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    elif args.reference_audio:
        name = Path(args.reference_audio).stem
    else:
        name = "results"  # Default name if no manifest/split/audio is provided
    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
        f.write(s)

    # --- Statistics fetching using a temporary async client ---
    # Use a separate async client for fetching stats regardless of mode.
    stats_client = None
    try:
        print("Initializing temporary async client for fetching stats...")
        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)

        print("Fetching inference statistics...")
        # Fetch stats for all models; filtering might be needed depending on the server setup.
        stats = await stats_client.get_inference_statistics(model_name="", as_json=True)

        print("Fetching model config...")
        metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)

        write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
        with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
            json.dump(metadata, f, indent=4)
    except Exception as e:
        print(f"Could not retrieve statistics or config: {e}")
    finally:
        if stats_client:
            try:
                print("Closing temporary async stats client...")
                await stats_client.close()
            except Exception as e:
                print(f"Error closing async stats client: {e}")
    # --- End statistics fetching ---


if __name__ == "__main__":
    # asyncio.run(main())  # Use TaskGroup for better exception handling if needed

    async def run_main():
        try:
            await main()
        except Exception as e:
            print(f"An error occurred in main: {e}")
            import traceback

            traceback.print_exc()

    asyncio.run(run_main())