client_grpc.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831
  1. #!/usr/bin/env python3
  2. # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
  3. # 2023 Nvidia (authors: Yuekai Zhang)
  4. # 2023 Recurrent.ai (authors: Songtao Shi)
  5. # See LICENSE for clarification regarding multiple authors
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. """
  19. This script supports to load dataset from huggingface and sends it to the server
  20. for decoding, in parallel.
  21. Usage:
  22. num_task=2
  23. # For offline F5-TTS
  24. python3 client_grpc.py \
  25. --server-addr localhost \
  26. --model-name f5_tts \
  27. --num-tasks $num_task \
  28. --huggingface-dataset yuekai/seed_tts \
  29. --split-name test_zh \
  30. --log-dir ./log_concurrent_tasks_${num_task}
  31. # For offline Spark-TTS-0.5B
  32. python3 client_grpc.py \
  33. --server-addr localhost \
  34. --model-name spark_tts \
  35. --num-tasks $num_task \
  36. --huggingface-dataset yuekai/seed_tts \
  37. --split-name wenetspeech4tts \
  38. --log-dir ./log_concurrent_tasks_${num_task}
  39. """
  40. import argparse
  41. import asyncio
  42. import json
  43. import queue # Added
  44. import uuid # Added
  45. import functools # Added
  46. import os
  47. import time
  48. import types
  49. from pathlib import Path
  50. import numpy as np
  51. import soundfile as sf
  52. import tritonclient
  53. import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
  54. import tritonclient.grpc as grpcclient_sync # Added sync client import
  55. from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
  56. # --- Added UserData and callback ---
  57. class UserData:
  58. def __init__(self):
  59. self._completed_requests = queue.Queue()
  60. self._first_chunk_time = None
  61. self._start_time = None
  62. def record_start_time(self):
  63. self._start_time = time.time()
  64. def get_first_chunk_latency(self):
  65. if self._first_chunk_time and self._start_time:
  66. return self._first_chunk_time - self._start_time
  67. return None
  68. def callback(user_data, result, error):
  69. if user_data._first_chunk_time is None and not error:
  70. user_data._first_chunk_time = time.time() # Record time of first successful chunk
  71. if error:
  72. user_data._completed_requests.put(error)
  73. else:
  74. user_data._completed_requests.put(result)
  75. # --- End Added UserData and callback ---
def write_triton_stats(stats, summary_file):
    """Render the JSON from triton_client.get_inference_statistics() into a
    human-readable text summary at `summary_file`.

    Writes per-model cumulative queue/compute times and per-batch-size
    execution counts and average latencies.
    """
    with open(summary_file, "w") as summary_f:
        model_stats = stats["model_stats"]
        # write a note, the log is from triton_client.get_inference_statistics(), to better human readability
        summary_f.write(
            "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
        )
        summary_f.write("To learn more about the log, please refer to: \n")
        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
        summary_f.write(
            "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
        )
        summary_f.write(
            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
        )
        summary_f.write(
            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
        )
        summary_f.write(
            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
        )
        for model_state in model_stats:
            # Models that never executed have no "last_inference" key; skip them.
            if "last_inference" not in model_state:
                continue
            summary_f.write(f"model name is {model_state['name']} \n")
            model_inference_stats = model_state["inference_stats"]
            # Cumulative times are reported in nanoseconds; convert to seconds.
            total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
            total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
            total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
            total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
            summary_f.write(
                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
            )
            model_batch_stats = model_state["batch_stats"]
            for batch in model_batch_stats:
                batch_size = int(batch["batch_size"])
                compute_input = batch["compute_input"]
                compute_output = batch["compute_output"]
                compute_infer = batch["compute_infer"]
                batch_count = int(compute_infer["count"])
                # Counts must agree across the three compute phases.
                assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
                # Per-batch-size cumulative times, converted from ns to ms.
                compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
                compute_input_time_ms = int(compute_input["ns"]) / 1e6
                compute_output_time_ms = int(compute_output["ns"]) / 1e6
                summary_f.write(
                    f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
                )
                summary_f.write(
                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
                )
                summary_f.write(
                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
                )
def get_args():
    """Parse command-line arguments for the benchmark client.

    Three mutually prioritized input sources: --reference-audio (single item),
    --huggingface-dataset (default), or --manifest-path. --mode selects the
    offline (single response) or streaming (chunked) benchmark path.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=8001,
        help="Grpc port of the triton server, default is 8001",
    )
    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
    )
    parser.add_argument(
        "--reference-text",
        type=str,
        default="",
        help="",
    )
    parser.add_argument(
        "--target-text",
        type=str,
        default="",
        help="",
    )
    parser.add_argument(
        "--huggingface-dataset",
        type=str,
        default="yuekai/seed_tts",
        help="dataset name in huggingface dataset hub",
    )
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="dataset split name, default is 'test'",
    )
    parser.add_argument(
        "--manifest-path",
        type=str,
        default=None,
        help="Path to the manifest dir which includes wav.scp trans.txt files.",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        choices=["f5_tts", "spark_tts", "cosyvoice2"],
        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
    )
    parser.add_argument(
        "--num-tasks",
        type=int,
        default=1,
        help="Number of concurrent tasks for sending",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=5,
        help="Controls how frequently we print the log.",
    )
    parser.add_argument(
        "--compute-wer",
        action="store_true",
        default=False,
        help="""True to compute WER.
        """,
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        required=False,
        default="./tmp",
        help="log directory",
    )
    # --- Added arguments ---
    parser.add_argument(
        "--mode",
        type=str,
        default="offline",
        choices=["offline", "streaming"],
        help="Select offline or streaming benchmark mode."
    )
    parser.add_argument(
        "--chunk-overlap-duration",
        type=float,
        default=0.1,
        help="Chunk overlap duration for streaming reconstruction (in seconds)."
    )
    # --- End Added arguments ---
    return parser.parse_args()
  230. def load_audio(wav_path, target_sample_rate=16000):
  231. assert target_sample_rate == 16000, "hard coding in server"
  232. if isinstance(wav_path, dict):
  233. waveform = wav_path["array"]
  234. sample_rate = wav_path["sampling_rate"]
  235. else:
  236. waveform, sample_rate = sf.read(wav_path)
  237. if sample_rate != target_sample_rate:
  238. from scipy.signal import resample
  239. num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
  240. waveform = resample(waveform, num_samples)
  241. return waveform, target_sample_rate
  242. def prepare_request_input_output(
  243. protocol_client, # Can be grpcclient_aio or grpcclient_sync
  244. waveform,
  245. reference_text,
  246. target_text,
  247. sample_rate=16000,
  248. padding_duration: int = None # Optional padding for offline mode
  249. ):
  250. """Prepares inputs for Triton inference (offline or streaming)."""
  251. assert len(waveform.shape) == 1, "waveform should be 1D"
  252. lengths = np.array([[len(waveform)]], dtype=np.int32)
  253. # Apply padding only if padding_duration is provided (for offline)
  254. if padding_duration:
  255. duration = len(waveform) / sample_rate
  256. # Estimate target duration based on text length ratio (crude estimation)
  257. # Avoid division by zero if reference_text is empty
  258. if reference_text:
  259. estimated_target_duration = duration / len(reference_text) * len(target_text)
  260. else:
  261. estimated_target_duration = duration # Assume target duration similar to reference if no text
  262. # Calculate required samples based on estimated total duration
  263. required_total_samples = padding_duration * sample_rate * (
  264. (int(estimated_target_duration + duration) // padding_duration) + 1
  265. )
  266. samples = np.zeros((1, required_total_samples), dtype=np.float32)
  267. samples[0, : len(waveform)] = waveform
  268. else:
  269. # No padding for streaming or if padding_duration is None
  270. samples = waveform.reshape(1, -1).astype(np.float32)
  271. # Common input creation logic
  272. inputs = [
  273. protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
  274. protocol_client.InferInput(
  275. "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
  276. ),
  277. protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
  278. protocol_client.InferInput("target_text", [1, 1], "BYTES"),
  279. ]
  280. inputs[0].set_data_from_numpy(samples)
  281. inputs[1].set_data_from_numpy(lengths)
  282. input_data_numpy = np.array([reference_text], dtype=object)
  283. input_data_numpy = input_data_numpy.reshape((1, 1))
  284. inputs[2].set_data_from_numpy(input_data_numpy)
  285. input_data_numpy = np.array([target_text], dtype=object)
  286. input_data_numpy = input_data_numpy.reshape((1, 1))
  287. inputs[3].set_data_from_numpy(input_data_numpy)
  288. outputs = [protocol_client.InferRequestedOutput("waveform")]
  289. return inputs, outputs
  290. def run_sync_streaming_inference(
  291. sync_triton_client: tritonclient.grpc.InferenceServerClient,
  292. model_name: str,
  293. inputs: list,
  294. outputs: list,
  295. request_id: str,
  296. user_data: UserData,
  297. chunk_overlap_duration: float,
  298. save_sample_rate: int,
  299. audio_save_path: str,
  300. ):
  301. """Helper function to run the blocking sync streaming call."""
  302. start_time_total = time.time()
  303. user_data.record_start_time() # Record start time for first chunk latency calculation
  304. # Establish stream
  305. sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
  306. # Send request
  307. sync_triton_client.async_stream_infer(
  308. model_name,
  309. inputs,
  310. request_id=request_id,
  311. outputs=outputs,
  312. enable_empty_final_response=True,
  313. )
  314. # Process results
  315. audios = []
  316. while True:
  317. try:
  318. result = user_data._completed_requests.get() # Add timeout
  319. if isinstance(result, InferenceServerException):
  320. print(f"Received InferenceServerException: {result}")
  321. sync_triton_client.stop_stream()
  322. return None, None, None # Indicate error
  323. # Get response metadata
  324. response = result.get_response()
  325. final = response.parameters["triton_final_response"].bool_param
  326. if final is True:
  327. break
  328. audio_chunk = result.as_numpy("waveform").reshape(-1)
  329. if audio_chunk.size > 0: # Only append non-empty chunks
  330. audios.append(audio_chunk)
  331. else:
  332. print("Warning: received empty audio chunk.")
  333. except queue.Empty:
  334. print(f"Timeout waiting for response for request id {request_id}")
  335. sync_triton_client.stop_stream()
  336. return None, None, None # Indicate error
  337. sync_triton_client.stop_stream()
  338. end_time_total = time.time()
  339. total_request_latency = end_time_total - start_time_total
  340. first_chunk_latency = user_data.get_first_chunk_latency()
  341. # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
  342. actual_duration = 0
  343. if audios:
  344. cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
  345. fade_out = np.linspace(1, 0, cross_fade_samples)
  346. fade_in = np.linspace(0, 1, cross_fade_samples)
  347. reconstructed_audio = None
  348. # Simplified reconstruction based on client_grpc_streaming.py
  349. if not audios:
  350. print("Warning: No audio chunks received.")
  351. reconstructed_audio = np.array([], dtype=np.float32) # Empty array
  352. elif len(audios) == 1:
  353. reconstructed_audio = audios[0]
  354. else:
  355. reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
  356. for i in range(1, len(audios)):
  357. # Cross-fade section
  358. cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
  359. audios[i - 1][-cross_fade_samples:] * fade_out)
  360. # Middle section of the current chunk
  361. middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
  362. # Concatenate
  363. reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
  364. # Add the last part of the final chunk
  365. reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
  366. if reconstructed_audio is not None and reconstructed_audio.size > 0:
  367. actual_duration = len(reconstructed_audio) / save_sample_rate
  368. # Save reconstructed audio
  369. os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
  370. sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
  371. else:
  372. print("Warning: No audio chunks received or reconstructed.")
  373. actual_duration = 0 # Set duration to 0 if no audio
  374. else:
  375. print("Warning: No audio chunks received.")
  376. actual_duration = 0
  377. return total_request_latency, first_chunk_latency, actual_duration
  378. async def send_streaming(
  379. manifest_item_list: list,
  380. name: str,
  381. server_url: str, # Changed from sync_triton_client
  382. protocol_client: types.ModuleType,
  383. log_interval: int,
  384. model_name: str,
  385. audio_save_dir: str = "./",
  386. save_sample_rate: int = 16000,
  387. chunk_overlap_duration: float = 0.1,
  388. padding_duration: int = None,
  389. ):
  390. total_duration = 0.0
  391. latency_data = []
  392. task_id = int(name[5:])
  393. sync_triton_client = None # Initialize client variable
  394. try: # Wrap in try...finally to ensure client closing
  395. print(f"{name}: Initializing sync client for streaming...")
  396. sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
  397. print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
  398. for i, item in enumerate(manifest_item_list):
  399. if i % log_interval == 0:
  400. print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
  401. try:
  402. waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
  403. reference_text, target_text = item["reference_text"], item["target_text"]
  404. inputs, outputs = prepare_request_input_output(
  405. protocol_client,
  406. waveform,
  407. reference_text,
  408. target_text,
  409. sample_rate,
  410. padding_duration=padding_duration
  411. )
  412. request_id = str(uuid.uuid4())
  413. user_data = UserData()
  414. audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
  415. total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
  416. run_sync_streaming_inference,
  417. sync_triton_client,
  418. model_name,
  419. inputs,
  420. outputs,
  421. request_id,
  422. user_data,
  423. chunk_overlap_duration,
  424. save_sample_rate,
  425. audio_save_path
  426. )
  427. if total_request_latency is not None:
  428. print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
  429. latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
  430. total_duration += actual_duration
  431. else:
  432. print(f"{name}: Item {i} failed.")
  433. except FileNotFoundError:
  434. print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
  435. except Exception as e:
  436. print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
  437. import traceback
  438. traceback.print_exc()
  439. finally: # Ensure client is closed
  440. if sync_triton_client:
  441. try:
  442. print(f"{name}: Closing sync client...")
  443. sync_triton_client.close()
  444. except Exception as e:
  445. print(f"{name}: Error closing sync client: {e}")
  446. print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
  447. return total_duration, latency_data
  448. async def send(
  449. manifest_item_list: list,
  450. name: str,
  451. triton_client: tritonclient.grpc.aio.InferenceServerClient,
  452. protocol_client: types.ModuleType,
  453. log_interval: int,
  454. model_name: str,
  455. padding_duration: int = None,
  456. audio_save_dir: str = "./",
  457. save_sample_rate: int = 16000,
  458. ):
  459. total_duration = 0.0
  460. latency_data = []
  461. task_id = int(name[5:])
  462. print(f"manifest_item_list: {manifest_item_list}")
  463. for i, item in enumerate(manifest_item_list):
  464. if i % log_interval == 0:
  465. print(f"{name}: {i}/{len(manifest_item_list)}")
  466. waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
  467. reference_text, target_text = item["reference_text"], item["target_text"]
  468. inputs, outputs = prepare_request_input_output(
  469. protocol_client,
  470. waveform,
  471. reference_text,
  472. target_text,
  473. sample_rate,
  474. padding_duration=padding_duration
  475. )
  476. sequence_id = 100000000 + i + task_id * 10
  477. start = time.time()
  478. response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
  479. audio = response.as_numpy("waveform").reshape(-1)
  480. actual_duration = len(audio) / save_sample_rate
  481. end = time.time() - start
  482. audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
  483. sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
  484. latency_data.append((end, actual_duration))
  485. total_duration += actual_duration
  486. return total_duration, latency_data
  487. def load_manifests(manifest_path):
  488. with open(manifest_path, "r") as f:
  489. manifest_list = []
  490. for line in f:
  491. assert len(line.strip().split("|")) == 4
  492. utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
  493. utt = Path(utt).stem
  494. # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
  495. if not os.path.isabs(prompt_wav):
  496. prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
  497. manifest_list.append(
  498. {
  499. "audio_filepath": prompt_wav,
  500. "reference_text": prompt_text,
  501. "target_text": gt_text,
  502. "target_audio_path": utt,
  503. }
  504. )
  505. return manifest_list
  506. def split_data(data, k):
  507. n = len(data)
  508. if n < k:
  509. print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
  510. k = n
  511. quotient = n // k
  512. remainder = n % k
  513. result = []
  514. start = 0
  515. for i in range(k):
  516. if i < remainder:
  517. end = start + quotient + 1
  518. else:
  519. end = start + quotient
  520. result.append(data[start:end])
  521. start = end
  522. return result
async def main():
    """Entry point: build the work list, run concurrent tasks, report stats.

    Steps: parse args -> pick async (offline) or sync-per-task (streaming)
    client strategy -> assemble manifest items from a single reference audio,
    a HuggingFace dataset, or a manifest file -> shard across tasks ->
    gather results -> write RTF/latency report and Triton server statistics.
    """
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"
    # --- Client Initialization based on mode ---
    triton_client = None
    protocol_client = None
    if args.mode == "offline":
        print("Initializing gRPC client for offline mode...")
        # Use the async client for offline tasks
        triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        protocol_client = grpcclient_aio
    elif args.mode == "streaming":
        print("Initializing gRPC client for streaming mode...")
        # Use the sync client for streaming tasks, handled via asyncio.to_thread
        # We will create one sync client instance PER TASK inside send_streaming.
        # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
        protocol_client = grpcclient_sync # protocol client for input prep
    else:
        raise ValueError(f"Invalid mode: {args.mode}")
    # --- End Client Initialization ---
    # Build the manifest: single reference audio takes priority, then the
    # HuggingFace dataset (non-None by default), then a manifest file.
    if args.reference_audio:
        args.num_tasks = 1
        args.log_interval = 1
        manifest_item_list = [
            {
                "reference_text": args.reference_text,
                "target_text": args.target_text,
                "audio_filepath": args.reference_audio,
                "target_audio_path": "test",
            }
        ]
    elif args.huggingface_dataset:
        import datasets
        dataset = datasets.load_dataset(
            args.huggingface_dataset,
            split=args.split_name,
            trust_remote_code=True,
        )
        manifest_item_list = []
        for i in range(len(dataset)):
            manifest_item_list.append(
                {
                    "audio_filepath": dataset[i]["prompt_audio"],
                    "reference_text": dataset[i]["prompt_text"],
                    "target_audio_path": dataset[i]["id"],
                    "target_text": dataset[i]["target_text"],
                }
            )
    else:
        manifest_item_list = load_manifests(args.manifest_path)
    # Shard the items across at most num_tasks workers.
    num_tasks = min(args.num_tasks, len(manifest_item_list))
    manifest_item_list = split_data(manifest_item_list, num_tasks)
    os.makedirs(args.log_dir, exist_ok=True)
    tasks = []
    start_time = time.time()
    for i in range(num_tasks):
        # --- Task Creation based on mode ---
        if args.mode == "offline":
            task = asyncio.create_task(
                send(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    triton_client=triton_client,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=1,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                )
            )
        elif args.mode == "streaming":
            task = asyncio.create_task(
                send_streaming(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    server_url=url, # Pass URL instead of client
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=10,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    chunk_overlap_duration=args.chunk_overlap_duration,
                )
            )
        # --- End Task Creation ---
        tasks.append(task)
    ans_list = await asyncio.gather(*tasks)
    end_time = time.time()
    elapsed = end_time - start_time
    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        if ans:
            total_duration += ans[0]
            latency_data.extend(ans[1]) # Use extend for list of lists
        else:
            print("Warning: A task returned None, possibly due to an error.")
    # RTF = wall-clock time / synthesized audio time (lower is better).
    if total_duration == 0:
        print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
        rtf = float('inf')
    else:
        rtf = elapsed / total_duration
    s = f"Mode: {args.mode}\n"
    s += f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration / 3600:.2f} hours)\n"
    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
    # --- Statistics Reporting based on mode ---
    # Offline entries are (latency, duration); streaming entries are
    # (total_latency, first_chunk_latency, duration).
    if latency_data:
        if args.mode == "offline":
            # Original offline latency calculation
            latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
            if latency_list:
                latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
                latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
                s += f"latency_variance: {latency_variance:.2f}\n"
                s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
                s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
                s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
                s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_latency_ms: {latency_ms:.2f}\n"
            else:
                s += "No latency data collected for offline mode.\n"
        elif args.mode == "streaming":
            # Calculate stats for total request latency and first chunk latency
            total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
            first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
            s += "\n--- Total Request Latency ---\n"
            if total_latency_list:
                avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
                variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
                s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
                s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
                s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
                s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
                s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
            else:
                s += "No total request latency data collected.\n"
            s += "\n--- First Chunk Latency ---\n"
            if first_chunk_latency_list:
                avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
                variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
                s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
                s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
            else:
                s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
    else:
        s += "No latency data collected.\n"
    # --- End Statistics Reporting ---
    print(s)
    # Pick a report name from whichever input source was used.
    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    elif args.reference_audio:
        name = Path(args.reference_audio).stem
    else:
        name = "results" # Default name if no manifest/split/audio provided
    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
        f.write(s)
    # --- Statistics Fetching using temporary Async Client ---
    # Use a separate async client for fetching stats regardless of mode
    stats_client = None
    try:
        print("Initializing temporary async client for fetching stats...")
        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        print("Fetching inference statistics...")
        # Fetching for all models, filtering might be needed depending on server setup
        stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
        print("Fetching model config...")
        metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
        write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
        with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
            json.dump(metadata, f, indent=4)
    except Exception as e:
        print(f"Could not retrieve statistics or config: {e}")
    finally:
        if stats_client:
            try:
                print("Closing temporary async stats client...")
                await stats_client.close()
            except Exception as e:
                print(f"Error closing async stats client: {e}")
    # --- End Statistics Fetching ---
  714. if __name__ == "__main__":
  715. # asyncio.run(main()) # Use TaskGroup for better exception handling if needed
  716. async def run_main():
  717. try:
  718. await main()
  719. except Exception as e:
  720. print(f"An error occurred in main: {e}")
  721. import traceback
  722. traceback.print_exc()
  723. asyncio.run(run_main())