# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#                2023 Nvidia (authors: Yuekai Zhang)
#                2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a dataset from the HuggingFace hub and sends it to the server
for decoding, in parallel.

Usage:
num_task=2

# For offline F5-TTS
python3 client_grpc.py \
    --server-addr localhost \
    --model-name f5_tts \
    --num-tasks $num_task \
    --huggingface-dataset yuekai/seed_tts \
    --split-name test_zh \
    --log-dir ./log_concurrent_tasks_${num_task}

# For offline Spark-TTS-0.5B
python3 client_grpc.py \
    --server-addr localhost \
    --model-name spark_tts \
    --num-tasks $num_task \
    --huggingface-dataset yuekai/seed_tts \
    --split-name wenetspeech4tts \
    --log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
import asyncio
import json
import queue  # Added
import uuid  # Added
import functools  # Added
import os
import time
import types
from pathlib import Path

import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
import tritonclient.grpc as grpcclient_sync  # Added sync client import
from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException


# --- Added UserData and callback ---
class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()
        self._first_chunk_time = None
        self._start_time = None

    def record_start_time(self):
        self._start_time = time.time()

    def get_first_chunk_latency(self):
        if self._first_chunk_time and self._start_time:
            return self._first_chunk_time - self._start_time
        return None


def callback(user_data, result, error):
    if user_data._first_chunk_time is None and not error:
        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)


# --- End Added UserData and callback ---
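
# Note (added for clarity): the sync gRPC client delivers streamed responses by
# calling `callback`; results and errors are pushed onto
# `UserData._completed_requests` and drained by `run_sync_streaming_inference` below.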


def write_triton_stats(stats, summary_file):
    with open(summary_file, "w") as summary_f:
        model_stats = stats["model_stats"]
        # Write a note that the log comes from triton_client.get_inference_statistics(), for better human readability
        summary_f.write(
            "The log is parsed from triton_client.get_inference_statistics(), for better human readability. \n"
        )
        summary_f.write("To learn more about the log, please refer to: \n")
        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
        summary_f.write(
            "To improve throughput, we would like to let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
        )
        summary_f.write(
            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
        )
        summary_f.write(
            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
        )
        summary_f.write(
            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
        )
        for model_state in model_stats:
            if "last_inference" not in model_state:
                continue
            summary_f.write(f"model name is {model_state['name']} \n")
            model_inference_stats = model_state["inference_stats"]
            total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
            total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
            total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
            total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
            summary_f.write(
                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
            )
            model_batch_stats = model_state["batch_stats"]
            for batch in model_batch_stats:
                batch_size = int(batch["batch_size"])
                compute_input = batch["compute_input"]
                compute_output = batch["compute_output"]
                compute_infer = batch["compute_infer"]
                batch_count = int(compute_infer["count"])
                assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
                compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
                compute_input_time_ms = int(compute_input["ns"]) / 1e6
                compute_output_time_ms = int(compute_output["ns"]) / 1e6
                summary_f.write(
                    f"execute inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
                )
                summary_f.write(
                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
                )
                summary_f.write(
                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
                )
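
# For reference, a minimal sketch of the delayed-batching stanza that the note above
# refers to, in a model's config.pbtxt (values are placeholders, not a recommendation):
#
#   dynamic_batching {
#     preferred_batch_size: [ 4, 8 ]
#     max_queue_delay_microseconds: 100
#   }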


def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=8001,
        help="gRPC port of the Triton server, default is 8001",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Path to a single audio file. It can't be specified at the same time as --manifest-path",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="",
        help="Transcript of the reference audio",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="",
        help="Text to synthesize",
    )

    parser.add_argument(
        "--huggingface-dataset",
        type=str,
        default="yuekai/seed_tts",
        help="dataset name in huggingface dataset hub",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="dataset split name, default is 'wenetspeech4tts'",
    )

    parser.add_argument(
        "--manifest-path",
        type=str,
        default=None,
        help="Path to a manifest file whose lines are formatted as 'utt|prompt_text|prompt_wav|gt_text'.",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        choices=["f5_tts", "spark_tts", "cosyvoice2", "cosyvoice2_dit"],
        help="triton model_repo module name to request",
    )

    parser.add_argument(
        "--num-tasks",
        type=int,
        default=1,
        help="Number of concurrent tasks for sending",
    )

    parser.add_argument(
        "--log-interval",
        type=int,
        default=5,
        help="Controls how frequently we print the log.",
    )

    parser.add_argument(
        "--compute-wer",
        action="store_true",
        default=False,
        help="True to compute WER.",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        required=False,
        default="./tmp",
        help="log directory",
    )

    # --- Added arguments ---
    parser.add_argument(
        "--mode",
        type=str,
        default="offline",
        choices=["offline", "streaming"],
        help="Select offline or streaming benchmark mode.",
    )

    parser.add_argument(
        "--chunk-overlap-duration",
        type=float,
        default=0.1,
        help="Chunk overlap duration for streaming reconstruction (in seconds).",
    )

    parser.add_argument(
        "--use-spk2info-cache",
        type=str,
        default="False",
        help="Use spk2info cache for reference audio.",
    )

    return parser.parse_args()


def load_audio(wav_path, target_sample_rate=16000):
    assert target_sample_rate == 16000, "sample rate is hard-coded to 16 kHz in the server"
    if isinstance(wav_path, dict):
        waveform = wav_path["array"]
        sample_rate = wav_path["sampling_rate"]
    else:
        waveform, sample_rate = sf.read(wav_path)
    if sample_rate != target_sample_rate:
        from scipy.signal import resample

        num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
        waveform = resample(waveform, num_samples)
    return waveform, target_sample_rate
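
# Note (added for clarity): HuggingFace datasets expose audio columns as dicts, e.g.
#   {"array": np.zeros(16000), "sampling_rate": 16000}
# load_audio accepts either that form or a plain file path.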


def prepare_request_input_output(
    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None,  # Optional padding for offline mode
    use_spk2info_cache: bool = False,
):
    """Prepares inputs for Triton inference (offline or streaming)."""
    assert len(waveform.shape) == 1, "waveform should be 1D"
    lengths = np.array([[len(waveform)]], dtype=np.int32)

    # Apply padding only if padding_duration is provided (for offline)
    if padding_duration:
        duration = len(waveform) / sample_rate
        # Estimate target duration based on text length ratio (crude estimation)
        # Avoid division by zero if reference_text is empty
        if reference_text:
            estimated_target_duration = duration / len(reference_text) * len(target_text)
        else:
            estimated_target_duration = duration  # Assume target duration similar to reference if no text

        # Calculate required samples based on estimated total duration
        required_total_samples = padding_duration * sample_rate * (
            (int(estimated_target_duration + duration) // padding_duration) + 1
        )
        samples = np.zeros((1, required_total_samples), dtype=np.float32)
        samples[0, : len(waveform)] = waveform
    else:
        # No padding for streaming or if padding_duration is None
        samples = waveform.reshape(1, -1).astype(np.float32)

    # Common input creation logic
    inputs = [
        protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
        protocol_client.InferInput(
            "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
        ),
        protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
        protocol_client.InferInput("target_text", [1, 1], "BYTES"),
    ]
    inputs[0].set_data_from_numpy(samples)
    inputs[1].set_data_from_numpy(lengths)

    input_data_numpy = np.array([reference_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[2].set_data_from_numpy(input_data_numpy)

    input_data_numpy = np.array([target_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[3].set_data_from_numpy(input_data_numpy)

    outputs = [protocol_client.InferRequestedOutput("waveform")]

    if use_spk2info_cache:
        inputs = inputs[-1:]
    return inputs, outputs
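
# Illustrative usage (names are examples only; assumes a 16 kHz mono prompt clip):
#   waveform, sr = load_audio("prompt.wav")
#   inputs, outputs = prepare_request_input_output(
#       grpcclient_sync, waveform, "prompt transcript", "text to synthesize", sr,
#       padding_duration=1,
#   )
# When use_spk2info_cache is True, only the `target_text` input is sent and the
# server is expected to reuse cached speaker/prompt information.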


def run_sync_streaming_inference(
    sync_triton_client: tritonclient.grpc.InferenceServerClient,
    model_name: str,
    inputs: list,
    outputs: list,
    request_id: str,
    user_data: UserData,
    chunk_overlap_duration: float,
    save_sample_rate: int,
    audio_save_path: str,
):
    """Helper function to run the blocking sync streaming call."""
    start_time_total = time.time()
    user_data.record_start_time()  # Record start time for first chunk latency calculation

    # Establish stream
    sync_triton_client.start_stream(callback=functools.partial(callback, user_data))

    # Send request
    sync_triton_client.async_stream_infer(
        model_name,
        inputs,
        request_id=request_id,
        outputs=outputs,
        enable_empty_final_response=True,
    )

    # Process results
    audios = []
    while True:
        try:
            result = user_data._completed_requests.get()  # Add timeout
            if isinstance(result, InferenceServerException):
                print(f"Received InferenceServerException: {result}")
                sync_triton_client.stop_stream()
                return None, None, None  # Indicate error
            # Get response metadata
            response = result.get_response()
            final = response.parameters["triton_final_response"].bool_param
            if final is True:
                break

            audio_chunk = result.as_numpy("waveform").reshape(-1)
            if audio_chunk.size > 0:  # Only append non-empty chunks
                audios.append(audio_chunk)
            else:
                print("Warning: received empty audio chunk.")
        except queue.Empty:
            print(f"Timeout waiting for response for request id {request_id}")
            sync_triton_client.stop_stream()
            return None, None, None  # Indicate error

    sync_triton_client.stop_stream()
    end_time_total = time.time()
    total_request_latency = end_time_total - start_time_total
    first_chunk_latency = user_data.get_first_chunk_latency()

    # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
    actual_duration = 0
    if audios:
        # Only spark_tts model uses cross-fade
        if model_name == "spark_tts":
            cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)
            reconstructed_audio = None

            # Simplified reconstruction based on client_grpc_streaming.py
            if not audios:
                print("Warning: No audio chunks received.")
                reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
            elif len(audios) == 1:
                reconstructed_audio = audios[0]
            else:
                reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
                for i in range(1, len(audios)):
                    # Cross-fade section
                    cross_faded_overlap = (
                        audios[i][:cross_fade_samples] * fade_in
                        + audios[i - 1][-cross_fade_samples:] * fade_out
                    )
                    # Middle section of the current chunk
                    middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
                    # Concatenate
                    reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
                # Add the last part of the final chunk
                reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])

            if reconstructed_audio is not None and reconstructed_audio.size > 0:
                actual_duration = len(reconstructed_audio) / save_sample_rate
                # Save reconstructed audio
                sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
            else:
                print("Warning: No audio chunks received or reconstructed.")
                actual_duration = 0  # Set duration to 0 if no audio
        else:
            reconstructed_audio = np.concatenate(audios)
            print(f"reconstructed_audio: {reconstructed_audio.shape}")
            actual_duration = len(reconstructed_audio) / save_sample_rate
            # Save reconstructed audio
            sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
    else:
        print("Warning: No audio chunks received.")
        actual_duration = 0

    return total_request_latency, first_chunk_latency, actual_duration


async def send_streaming(
    manifest_item_list: list,
    name: str,
    server_url: str,  # Changed from sync_triton_client
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    chunk_overlap_duration: float = 0.1,
    padding_duration: int = None,
    use_spk2info_cache: bool = False,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])

    sync_triton_client = None  # Initialize client variable
    try:  # Wrap in try...finally to ensure client closing
        print(f"{name}: Initializing sync client for streaming...")
        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here

        print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
        for i, item in enumerate(manifest_item_list):
            if i % log_interval == 0:
                print(f"{name}: Processing item {i}/{len(manifest_item_list)}")

            try:
                waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
                reference_text, target_text = item["reference_text"], item["target_text"]

                inputs, outputs = prepare_request_input_output(
                    protocol_client,
                    waveform,
                    reference_text,
                    target_text,
                    sample_rate,
                    padding_duration=padding_duration,
                    use_spk2info_cache=use_spk2info_cache,
                )
                request_id = str(uuid.uuid4())
                user_data = UserData()

                audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")

                total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
                    run_sync_streaming_inference,
                    sync_triton_client,
                    model_name,
                    inputs,
                    outputs,
                    request_id,
                    user_data,
                    chunk_overlap_duration,
                    save_sample_rate,
                    audio_save_path,
                )

                if total_request_latency is not None:
                    print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
                    latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                    total_duration += actual_duration
                else:
                    print(f"{name}: Item {i} failed.")

            except FileNotFoundError:
                print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
            except Exception as e:
                print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
                import traceback

                traceback.print_exc()

    finally:  # Ensure client is closed
        if sync_triton_client:
            try:
                print(f"{name}: Closing sync client...")
                sync_triton_client.close()
            except Exception as e:
                print(f"{name}: Error closing sync client: {e}")

    print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
    return total_duration, latency_data


async def send(
    manifest_item_list: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    padding_duration: int = None,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    use_spk2info_cache: bool = False,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])
    print(f"manifest_item_list: {manifest_item_list}")
    for i, item in enumerate(manifest_item_list):
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(manifest_item_list)}")

        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
        reference_text, target_text = item["reference_text"], item["target_text"]

        inputs, outputs = prepare_request_input_output(
            protocol_client,
            waveform,
            reference_text,
            target_text,
            sample_rate,
            padding_duration=padding_duration,
            use_spk2info_cache=use_spk2info_cache,
        )
        sequence_id = 100000000 + i + task_id * 10
        start = time.time()
        response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)

        audio = response.as_numpy("waveform").reshape(-1)
        actual_duration = len(audio) / save_sample_rate

        end = time.time() - start

        audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")

        latency_data.append((end, actual_duration))
        total_duration += actual_duration

    return total_duration, latency_data
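
# load_manifests (below) expects one "|"-separated record per line; an illustrative
# line (hypothetical names) looks like:
#   utt_0001|prompt transcript|wavs/utt_0001.wav|target text to synthesize
# Relative prompt-wav paths are resolved against the manifest's directory.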


def load_manifests(manifest_path):
    with open(manifest_path, "r") as f:
        manifest_list = []
        for line in f:
            assert len(line.strip().split("|")) == 4
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
            utt = Path(utt).stem
            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
            manifest_list.append(
                {
                    "audio_filepath": prompt_wav,
                    "reference_text": prompt_text,
                    "target_text": gt_text,
                    "target_audio_path": utt,
                }
            )
    return manifest_list
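
# Worked example (added for clarity): split_data(list(range(5)), 2) distributes the
# remainder to the earliest chunks, yielding [[0, 1, 2], [3, 4]].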


def split_data(data, k):
    n = len(data)
    if n < k:
        print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
        k = n

    quotient = n // k
    remainder = n % k

    result = []
    start = 0
    for i in range(k):
        if i < remainder:
            end = start + quotient + 1
        else:
            end = start + quotient

        result.append(data[start:end])
        start = end

    return result


async def main():
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"

    # --- Client Initialization based on mode ---
    triton_client = None
    protocol_client = None
    if args.mode == "offline":
        print("Initializing gRPC client for offline mode...")
        # Use the async client for offline tasks
        triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        protocol_client = grpcclient_aio
    elif args.mode == "streaming":
        print("Initializing gRPC client for streaming mode...")
        # Use the sync client for streaming tasks, handled via asyncio.to_thread
        # We will create one sync client instance PER TASK inside send_streaming.
        # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False)  # REMOVED: Client created per task now
        protocol_client = grpcclient_sync  # protocol client for input prep
    else:
        raise ValueError(f"Invalid mode: {args.mode}")
    # --- End Client Initialization ---

    if args.reference_audio:
        args.num_tasks = 1
        args.log_interval = 1
        manifest_item_list = [
            {
                "reference_text": args.reference_text,
                "target_text": args.target_text,
                "audio_filepath": args.reference_audio,
                "target_audio_path": "test",
            }
        ]
    elif args.huggingface_dataset:
        import datasets

        dataset = datasets.load_dataset(
            args.huggingface_dataset,
            split=args.split_name,
            trust_remote_code=True,
        )
        manifest_item_list = []
        tmp_audio_path = "./asset_zero_shot_prompt.wav"
        tmp_audio_text = "希望你以后能够做的比我还好呦。"
        for i in range(len(dataset)):
            manifest_item_list.append(
                {
                    "audio_filepath": dataset[i]["prompt_audio"],
                    "reference_text": dataset[i]["prompt_text"],
                    # "audio_filepath": tmp_audio_path,
                    # "reference_text": tmp_audio_text,
                    "target_audio_path": dataset[i]["id"],
                    "target_text": dataset[i]["target_text"],
                }
            )
    else:
        manifest_item_list = load_manifests(args.manifest_path)

    num_tasks = min(args.num_tasks, len(manifest_item_list))
    manifest_item_list = split_data(manifest_item_list, num_tasks)

    os.makedirs(args.log_dir, exist_ok=True)
    args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"

    tasks = []
    start_time = time.time()
    for i in range(num_tasks):
        # --- Task Creation based on mode ---
        if args.mode == "offline":
            task = asyncio.create_task(
                send(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    triton_client=triton_client,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=1,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            )
        elif args.mode == "streaming":
            task = asyncio.create_task(
                send_streaming(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    server_url=url,  # Pass URL instead of client
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=10,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    chunk_overlap_duration=args.chunk_overlap_duration,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            )
        # --- End Task Creation ---
        tasks.append(task)

    ans_list = await asyncio.gather(*tasks)

    end_time = time.time()
    elapsed = end_time - start_time

    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        if ans:
            total_duration += ans[0]
            latency_data.extend(ans[1])  # Use extend for list of lists
        else:
            print("Warning: A task returned None, possibly due to an error.")

    if total_duration == 0:
        print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
        rtf = float("inf")
    else:
        rtf = elapsed / total_duration

    s = f"Mode: {args.mode}\n"
    s += f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration / 3600:.2f} hours)\n"
    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

    # --- Statistics Reporting based on mode ---
    if latency_data:
        if args.mode == "offline":
            # Original offline latency calculation
            latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
            if latency_list:
                latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
                latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
                s += f"latency_variance: {latency_variance:.2f}\n"
                s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
                s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
                s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
                s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_latency_ms: {latency_ms:.2f}\n"
            else:
                s += "No latency data collected for offline mode.\n"
        elif args.mode == "streaming":
            # Calculate stats for total request latency and first chunk latency
            total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
            first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]

            s += "\n--- Total Request Latency ---\n"
            if total_latency_list:
                avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
                variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
                s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
                s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
                s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
                s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
                s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
            else:
                s += "No total request latency data collected.\n"

            s += "\n--- First Chunk Latency ---\n"
            if first_chunk_latency_list:
                avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
                variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
                s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
                s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
            else:
                s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
    else:
        s += "No latency data collected.\n"
    # --- End Statistics Reporting ---

    print(s)

    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    elif args.reference_audio:
        name = Path(args.reference_audio).stem
    else:
        name = "results"  # Default name if no manifest/split/audio provided
    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
        f.write(s)

    # --- Statistics Fetching using temporary Async Client ---
    # Use a separate async client for fetching stats regardless of mode
    stats_client = None
    try:
        print("Initializing temporary async client for fetching stats...")
        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        print("Fetching inference statistics...")
        # Fetching for all models, filtering might be needed depending on server setup
        stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
        print("Fetching model config...")
        metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)

        write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
        with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
            json.dump(metadata, f, indent=4)
    except Exception as e:
        print(f"Could not retrieve statistics or config: {e}")
    finally:
        if stats_client:
            try:
                print("Closing temporary async stats client...")
                await stats_client.close()
            except Exception as e:
                print(f"Error closing async stats client: {e}")
    # --- End Statistics Fetching ---


if __name__ == "__main__":
    # asyncio.run(main())  # Use TaskGroup for better exception handling if needed
    async def run_main():
        try:
            await main()
        except Exception as e:
            print(f"An error occurred in main: {e}")
            import traceback

            traceback.print_exc()

    asyncio.run(run_main())