client_grpc.py

# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#           2023 Nvidia (authors: Yuekai Zhang)
#           2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  17. """
  18. This script supports to load dataset from huggingface and sends it to the server
  19. for decoding, in parallel.
  20. Usage:
  21. num_task=2
  22. # For offline F5-TTS
  23. python3 client_grpc.py \
  24. --server-addr localhost \
  25. --model-name f5_tts \
  26. --num-tasks $num_task \
  27. --huggingface-dataset yuekai/seed_tts \
  28. --split-name test_zh \
  29. --log-dir ./log_concurrent_tasks_${num_task}
  30. # For offline Spark-TTS-0.5B
  31. python3 client_grpc.py \
  32. --server-addr localhost \
  33. --model-name spark_tts \
  34. --num-tasks $num_task \
  35. --huggingface-dataset yuekai/seed_tts \
  36. --split-name wenetspeech4tts \
  37. --log-dir ./log_concurrent_tasks_${num_task}
  38. """

import argparse
import asyncio
import functools
import json
import os
import queue
import time
import types
import uuid
from pathlib import Path

import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc as grpcclient_sync      # sync client, used for streaming requests
import tritonclient.grpc.aio as grpcclient_aio   # async client, used for offline requests
from tritonclient.utils import InferenceServerException, np_to_triton_dtype

# --- UserData and callback for the streaming (sync) client ---
class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()
        self._first_chunk_time = None
        self._start_time = None

    def record_start_time(self):
        self._start_time = time.time()

    def get_first_chunk_latency(self):
        if self._first_chunk_time and self._start_time:
            return self._first_chunk_time - self._start_time
        return None


def callback(user_data, result, error):
    if user_data._first_chunk_time is None and not error:
        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)
# --- End UserData and callback ---

def write_triton_stats(stats, summary_file):
    with open(summary_file, "w") as summary_f:
        model_stats = stats["model_stats"]
        # Add a note that the log comes from triton_client.get_inference_statistics(),
        # for better human readability.
        summary_f.write(
            "The log is parsed from triton_client.get_inference_statistics(), for better human readability. \n"
        )
        summary_f.write("To learn more about the log, please refer to: \n")
        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
        summary_f.write(
            "To improve throughput, we usually want requests to wait in the queue for a while, and then execute them with a larger batch size. \n"
        )
        summary_f.write(
            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
        )
        summary_f.write(
            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to tune this. \n"
        )
        summary_f.write(
            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
        )
        for model_state in model_stats:
            if "last_inference" not in model_state:
                continue
            summary_f.write(f"model name is {model_state['name']} \n")
            model_inference_stats = model_state["inference_stats"]
            total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
            total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
            total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
            total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
            summary_f.write(
                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
            )
            model_batch_stats = model_state["batch_stats"]
            for batch in model_batch_stats:
                batch_size = int(batch["batch_size"])
                compute_input = batch["compute_input"]
                compute_output = batch["compute_output"]
                compute_infer = batch["compute_infer"]
                batch_count = int(compute_infer["count"])
                assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
                compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
                compute_input_time_ms = int(compute_input["ns"]) / 1e6
                compute_output_time_ms = int(compute_output["ns"]) / 1e6
                summary_f.write(
                    f"execute inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
                )
                summary_f.write(
                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
                )
                summary_f.write(
                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
                )

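
# For reference, the delayed-batching knobs mentioned in write_triton_stats() live in the
# model's config.pbtxt. A minimal, illustrative sketch (the values are placeholders, not
# tuned recommendations; see the Triton model_configuration guide linked above):
#
#   dynamic_batching {
#     preferred_batch_size: [ 4, 8 ]
#     max_queue_delay_microseconds: 100
#   }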

def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=8001,
        help="gRPC port of the Triton server, default is 8001",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Path to a single audio file. It can't be specified together with --manifest-path",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="",
        help="Transcript of the reference audio (used together with --reference-audio)",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="",
        help="Text to synthesize (used together with --reference-audio)",
    )

    parser.add_argument(
        "--huggingface-dataset",
        type=str,
        default="yuekai/seed_tts",
        help="dataset name in huggingface dataset hub",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="dataset split name, default is 'wenetspeech4tts'",
    )

    parser.add_argument(
        "--manifest-path",
        type=str,
        default=None,
        help="Path to the manifest file; each line is 'utt|prompt_text|prompt_wav|gt_text'.",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        choices=["f5_tts", "spark_tts", "cosyvoice2"],
        help="triton model_repo module name to request",
    )

    parser.add_argument(
        "--num-tasks",
        type=int,
        default=1,
        help="Number of concurrent tasks for sending",
    )

    parser.add_argument(
        "--log-interval",
        type=int,
        default=5,
        help="Controls how frequently we print the log.",
    )

    parser.add_argument(
        "--compute-wer",
        action="store_true",
        default=False,
        help="True to compute WER.",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        required=False,
        default="./tmp",
        help="log directory",
    )
    # --- Mode, streaming, and cache arguments ---
    parser.add_argument(
        "--mode",
        type=str,
        default="offline",
        choices=["offline", "streaming"],
        help="Select offline or streaming benchmark mode.",
    )

    parser.add_argument(
        "--chunk-overlap-duration",
        type=float,
        default=0.1,
        help="Chunk overlap duration for streaming reconstruction (in seconds).",
    )

    parser.add_argument(
        "--use-spk2info-cache",
        action="store_true",  # type=bool would treat any non-empty string as True
        default=False,
        help="Use spk2info cache for reference audio.",
    )

    return parser.parse_args()

def load_audio(wav_path, target_sample_rate=16000):
    assert target_sample_rate == 16000, "sample rate is hard-coded to 16 kHz in the server"
    if isinstance(wav_path, dict):
        waveform = wav_path["array"]
        sample_rate = wav_path["sampling_rate"]
    else:
        waveform, sample_rate = sf.read(wav_path)
    if sample_rate != target_sample_rate:
        from scipy.signal import resample

        num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
        waveform = resample(waveform, num_samples)
    return waveform, target_sample_rate

def prepare_request_input_output(
    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None,  # Optional padding for offline mode
    use_spk2info_cache: bool = False,
):
    """Prepares inputs for Triton inference (offline or streaming)."""
    assert len(waveform.shape) == 1, "waveform should be 1D"
    lengths = np.array([[len(waveform)]], dtype=np.int32)

    # Apply padding only if padding_duration is provided (for offline)
    if padding_duration:
        duration = len(waveform) / sample_rate
        # Estimate target duration based on text length ratio (crude estimation)
        # Avoid division by zero if reference_text is empty
        if reference_text:
            estimated_target_duration = duration / len(reference_text) * len(target_text)
        else:
            estimated_target_duration = duration  # Assume target duration similar to reference if no text

        # Calculate required samples based on estimated total duration
        required_total_samples = padding_duration * sample_rate * (
            (int(estimated_target_duration + duration) // padding_duration) + 1
        )
        samples = np.zeros((1, required_total_samples), dtype=np.float32)
        samples[0, : len(waveform)] = waveform
    else:
        # No padding for streaming or if padding_duration is None
        samples = waveform.reshape(1, -1).astype(np.float32)

    # Common input creation logic
    inputs = [
        protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
        protocol_client.InferInput(
            "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
        ),
        protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
        protocol_client.InferInput("target_text", [1, 1], "BYTES"),
    ]
    inputs[0].set_data_from_numpy(samples)
    inputs[1].set_data_from_numpy(lengths)

    input_data_numpy = np.array([reference_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[2].set_data_from_numpy(input_data_numpy)

    input_data_numpy = np.array([target_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[3].set_data_from_numpy(input_data_numpy)

    outputs = [protocol_client.InferRequestedOutput("waveform")]

    if use_spk2info_cache:
        # With the spk2info cache enabled, only the target_text input is sent
        inputs = inputs[-1:]
    return inputs, outputs

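
# For reference, the request built above carries four Triton inputs (names and dtypes as set
# in prepare_request_input_output(); shapes use batch size 1):
#
#   reference_wav      FP32   [1, num_samples]   zero-padded in offline mode
#   reference_wav_len  INT32  [1, 1]             original sample count
#   reference_text     BYTES  [1, 1]             prompt transcript
#   target_text        BYTES  [1, 1]             text to synthesize
#
# and requests a single output tensor named "waveform".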

def run_sync_streaming_inference(
    sync_triton_client: tritonclient.grpc.InferenceServerClient,
    model_name: str,
    inputs: list,
    outputs: list,
    request_id: str,
    user_data: UserData,
    chunk_overlap_duration: float,
    save_sample_rate: int,
    audio_save_path: str,
):
    """Helper function to run the blocking sync streaming call."""
    start_time_total = time.time()
    user_data.record_start_time()  # Record start time for first chunk latency calculation

    # Establish stream
    sync_triton_client.start_stream(callback=functools.partial(callback, user_data))

    # Send request
    sync_triton_client.async_stream_infer(
        model_name,
        inputs,
        request_id=request_id,
        outputs=outputs,
        enable_empty_final_response=True,
    )

    # Process results
    audios = []
    while True:
        try:
            # The timeout makes the queue.Empty branch below reachable;
            # the value is a generous upper bound, not a tuned constant.
            result = user_data._completed_requests.get(timeout=300)
            if isinstance(result, InferenceServerException):
                print(f"Received InferenceServerException: {result}")
                sync_triton_client.stop_stream()
                return None, None, None  # Indicate error
            # Get response metadata
            response = result.get_response()
            final = response.parameters["triton_final_response"].bool_param
            if final is True:
                break

            audio_chunk = result.as_numpy("waveform").reshape(-1)
            if audio_chunk.size > 0:  # Only append non-empty chunks
                audios.append(audio_chunk)
            else:
                print("Warning: received empty audio chunk.")

        except queue.Empty:
            print(f"Timeout waiting for response for request id {request_id}")
            sync_triton_client.stop_stream()
            return None, None, None  # Indicate error

    sync_triton_client.stop_stream()
    end_time_total = time.time()
    total_request_latency = end_time_total - start_time_total
    first_chunk_latency = user_data.get_first_chunk_latency()

    # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
    actual_duration = 0
    if audios:
        # Only the spark_tts model uses cross-fade reconstruction
        if model_name == "spark_tts":
            cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)

            # Simplified reconstruction based on client_grpc_streaming.py
            if len(audios) == 1:
                reconstructed_audio = audios[0]
            else:
                reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
                for i in range(1, len(audios)):
                    # Cross-fade the overlapping region of consecutive chunks
                    cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
                                           audios[i - 1][-cross_fade_samples:] * fade_out)
                    # Middle section of the current chunk
                    middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
                    # Concatenate
                    reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
                # Add the last part of the final chunk
                reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])

            if reconstructed_audio is not None and reconstructed_audio.size > 0:
                actual_duration = len(reconstructed_audio) / save_sample_rate
                # Save reconstructed audio
                sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
            else:
                print("Warning: No audio chunks received or reconstructed.")
                actual_duration = 0  # Set duration to 0 if no audio
        else:
            reconstructed_audio = np.concatenate(audios)
            print(f"reconstructed_audio: {reconstructed_audio.shape}")
            actual_duration = len(reconstructed_audio) / save_sample_rate
            # Save reconstructed audio
            sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
    else:
        print("Warning: No audio chunks received.")
        actual_duration = 0

    return total_request_latency, first_chunk_latency, actual_duration

async def send_streaming(
    manifest_item_list: list,
    name: str,
    server_url: str,  # URL is passed instead of a shared sync client
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    chunk_overlap_duration: float = 0.1,
    padding_duration: int = None,
    use_spk2info_cache: bool = False,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])

    sync_triton_client = None  # Initialize client variable
    try:  # Wrap in try...finally to ensure the client is closed
        print(f"{name}: Initializing sync client for streaming...")
        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here

        print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
        for i, item in enumerate(manifest_item_list):
            if i % log_interval == 0:
                print(f"{name}: Processing item {i}/{len(manifest_item_list)}")

            try:
                waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
                reference_text, target_text = item["reference_text"], item["target_text"]

                inputs, outputs = prepare_request_input_output(
                    protocol_client,
                    waveform,
                    reference_text,
                    target_text,
                    sample_rate,
                    padding_duration=padding_duration,
                    use_spk2info_cache=use_spk2info_cache,
                )
                request_id = str(uuid.uuid4())
                user_data = UserData()

                audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")

                total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
                    run_sync_streaming_inference,
                    sync_triton_client,
                    model_name,
                    inputs,
                    outputs,
                    request_id,
                    user_data,
                    chunk_overlap_duration,
                    save_sample_rate,
                    audio_save_path,
                )

                if total_request_latency is not None:
                    print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
                    latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                    total_duration += actual_duration
                else:
                    print(f"{name}: Item {i} failed.")

            except FileNotFoundError:
                print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
            except Exception as e:
                print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
                import traceback
                traceback.print_exc()

    finally:  # Ensure client is closed
        if sync_triton_client:
            try:
                print(f"{name}: Closing sync client...")
                sync_triton_client.close()
            except Exception as e:
                print(f"{name}: Error closing sync client: {e}")

    print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
    return total_duration, latency_data

async def send(
    manifest_item_list: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    padding_duration: int = None,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    use_spk2info_cache: bool = False,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])

    print(f"manifest_item_list: {manifest_item_list}")
    for i, item in enumerate(manifest_item_list):
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(manifest_item_list)}")

        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
        reference_text, target_text = item["reference_text"], item["target_text"]

        inputs, outputs = prepare_request_input_output(
            protocol_client,
            waveform,
            reference_text,
            target_text,
            sample_rate,
            padding_duration=padding_duration,
            use_spk2info_cache=use_spk2info_cache,
        )
        sequence_id = 100000000 + i + task_id * 10

        start = time.time()
        response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)

        audio = response.as_numpy("waveform").reshape(-1)
        actual_duration = len(audio) / save_sample_rate

        end = time.time() - start

        audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")

        latency_data.append((end, actual_duration))
        total_duration += actual_duration

    return total_duration, latency_data

def load_manifests(manifest_path):
    with open(manifest_path, "r") as f:
        manifest_list = []
        for line in f:
            assert len(line.strip().split("|")) == 4
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
            utt = Path(utt).stem
            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
            manifest_list.append(
                {
                    "audio_filepath": prompt_wav,
                    "reference_text": prompt_text,
                    "target_text": gt_text,
                    "target_audio_path": utt,
                }
            )
    return manifest_list

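
# Illustrative manifest line consumed by load_manifests() above (file names are hypothetical);
# each line carries four '|'-separated fields: utt|prompt_text|prompt_wav|gt_text
#
#   utt_0001|prompt transcript text|wavs/utt_0001.wav|target text to synthesize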

def split_data(data, k):
    n = len(data)
    if n < k:
        print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
        k = n

    quotient = n // k
    remainder = n % k

    result = []
    start = 0
    for i in range(k):
        if i < remainder:
            end = start + quotient + 1
        else:
            end = start + quotient

        result.append(data[start:end])
        start = end

    return result

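
# Worked example for split_data() above: the first n % k sub-lists get one extra element each,
# so five items split over two tasks gives:
#
#   split_data([1, 2, 3, 4, 5], 2)  ->  [[1, 2, 3], [4, 5]]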

async def main():
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"

    # --- Client initialization based on mode ---
    triton_client = None
    protocol_client = None
    if args.mode == "offline":
        print("Initializing gRPC client for offline mode...")
        # Use the async client for offline tasks
        triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        protocol_client = grpcclient_aio
    elif args.mode == "streaming":
        print("Initializing gRPC client for streaming mode...")
        # Streaming uses the sync client, driven via asyncio.to_thread;
        # one sync client instance is created per task inside send_streaming.
        protocol_client = grpcclient_sync  # protocol client for input preparation
    else:
        raise ValueError(f"Invalid mode: {args.mode}")
    # --- End client initialization ---

    if args.reference_audio:
        args.num_tasks = 1
        args.log_interval = 1
        manifest_item_list = [
            {
                "reference_text": args.reference_text,
                "target_text": args.target_text,
                "audio_filepath": args.reference_audio,
                "target_audio_path": "test",
            }
        ]
    elif args.huggingface_dataset:
        import datasets

        dataset = datasets.load_dataset(
            args.huggingface_dataset,
            split=args.split_name,
            trust_remote_code=True,
        )
        manifest_item_list = []
        for i in range(len(dataset)):
            manifest_item_list.append(
                {
                    "audio_filepath": dataset[i]["prompt_audio"],
                    "reference_text": dataset[i]["prompt_text"],
                    "target_audio_path": dataset[i]["id"],
                    "target_text": dataset[i]["target_text"],
                }
            )
    else:
        manifest_item_list = load_manifests(args.manifest_path)

    num_tasks = min(args.num_tasks, len(manifest_item_list))
    manifest_item_list = split_data(manifest_item_list, num_tasks)

    os.makedirs(args.log_dir, exist_ok=True)
    tasks = []
    start_time = time.time()
    for i in range(num_tasks):
        # --- Task creation based on mode ---
        if args.mode == "offline":
            task = asyncio.create_task(
                send(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    triton_client=triton_client,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=1,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            )
        elif args.mode == "streaming":
            task = asyncio.create_task(
                send_streaming(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    server_url=url,  # Pass URL instead of a shared client
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=10,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    chunk_overlap_duration=args.chunk_overlap_duration,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            )
        # --- End task creation ---
        tasks.append(task)

    ans_list = await asyncio.gather(*tasks)

    end_time = time.time()
    elapsed = end_time - start_time
    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        if ans:
            total_duration += ans[0]
            latency_data.extend(ans[1])  # Use extend for list of lists
        else:
            print("Warning: A task returned None, possibly due to an error.")

    if total_duration == 0:
        print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
        rtf = float("inf")
    else:
        rtf = elapsed / total_duration

    s = f"Mode: {args.mode}\n"
    s += f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration / 3600:.2f} hours)\n"
    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

    # --- Statistics reporting based on mode ---
    if latency_data:
        if args.mode == "offline":
            # Original offline latency calculation
            latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
            if latency_list:
                latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
                latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
                s += f"latency_variance: {latency_variance:.2f}\n"
                s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
                s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
                s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
                s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_latency_ms: {latency_ms:.2f}\n"
            else:
                s += "No latency data collected for offline mode.\n"
        elif args.mode == "streaming":
            # Calculate stats for total request latency and first chunk latency
            total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
            first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]

            s += "\n--- Total Request Latency ---\n"
            if total_latency_list:
                avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
                variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
                s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
                s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
                s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
                s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
                s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
            else:
                s += "No total request latency data collected.\n"

            s += "\n--- First Chunk Latency ---\n"
            if first_chunk_latency_list:
                avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
                variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
                s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
                s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
            else:
                s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
    else:
        s += "No latency data collected.\n"
    # --- End statistics reporting ---

    print(s)
    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    elif args.reference_audio:
        name = Path(args.reference_audio).stem
    else:
        name = "results"  # Default name if no manifest/split/audio provided
    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
        f.write(s)

    # --- Statistics fetching using a temporary async client ---
    # Use a separate async client for fetching stats regardless of mode
    stats_client = None
    try:
        print("Initializing temporary async client for fetching stats...")
        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)

        print("Fetching inference statistics...")
        # Fetch stats for all models; filtering might be needed depending on server setup
        stats = await stats_client.get_inference_statistics(model_name="", as_json=True)

        print("Fetching model config...")
        metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)

        write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")

        with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
            json.dump(metadata, f, indent=4)
    except Exception as e:
        print(f"Could not retrieve statistics or config: {e}")
    finally:
        if stats_client:
            try:
                print("Closing temporary async stats client...")
                await stats_client.close()
            except Exception as e:
                print(f"Error closing async stats client: {e}")
    # --- End statistics fetching ---

if __name__ == "__main__":
    # asyncio.run(main())  # Use TaskGroup for better exception handling if needed
    async def run_main():
        try:
            await main()
        except Exception as e:
            print(f"An error occurred in main: {e}")
            import traceback
            traceback.print_exc()

    asyncio.run(run_main())