# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#           2023 Nvidia       (authors: Yuekai Zhang)
#           2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  17. """
  18. This script supports to load dataset from huggingface and sends it to the server
  19. for decoding, in parallel.
  20. Usage:
  21. num_task=2
  22. # For offline F5-TTS
  23. python3 client_grpc.py \
  24. --server-addr localhost \
  25. --model-name f5_tts \
  26. --num-tasks $num_task \
  27. --huggingface-dataset yuekai/seed_tts \
  28. --split-name test_zh \
  29. --log-dir ./log_concurrent_tasks_${num_task}
  30. # For offline Spark-TTS-0.5B
  31. python3 client_grpc.py \
  32. --server-addr localhost \
  33. --model-name spark_tts \
  34. --num-tasks $num_task \
  35. --huggingface-dataset yuekai/seed_tts \
  36. --split-name wenetspeech4tts \
  37. --log-dir ./log_concurrent_tasks_${num_task}
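
# For streaming benchmark mode (a sketch; pick a model such as cosyvoice2 that
# your server deployment actually serves in decoupled/streaming fashion)
python3 client_grpc.py \
    --server-addr localhost \
    --model-name cosyvoice2 \
    --mode streaming \
    --chunk-overlap-duration 0.1 \
    --num-tasks $num_task \
    --huggingface-dataset yuekai/seed_tts \
    --split-name test_zh \
    --log-dir ./log_concurrent_tasks_${num_task}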
  38. """
import argparse
import asyncio
import json
import queue
import uuid
import functools
import os
import time
import types
from pathlib import Path

import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient_aio
import tritonclient.grpc as grpcclient_sync
from tritonclient.utils import np_to_triton_dtype, InferenceServerException


class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()
        self._first_chunk_time = None
        self._second_chunk_time = None
        self._start_time = None

    def record_start_time(self):
        self._start_time = time.time()

    def get_first_chunk_latency(self):
        if self._first_chunk_time and self._start_time:
            return self._first_chunk_time - self._start_time
        return None

    def get_second_chunk_latency(self):
        if self._first_chunk_time and self._second_chunk_time:
            return self._second_chunk_time - self._first_chunk_time
        return None


def callback(user_data, result, error):
    if not error:
        if user_data._first_chunk_time is None:
            user_data._first_chunk_time = time.time()
        elif user_data._second_chunk_time is None:
            user_data._second_chunk_time = time.time()
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)


def stream_callback(user_data_map, result, error):
    request_id = None
    if error:
        print(f"An error occurred in the stream callback: {error}")
    else:
        request_id = result.get_response().id
    if request_id:
        user_data = user_data_map.get(request_id)
        if user_data:
            callback(user_data, result, error)
        else:
            print(f"Warning: Could not find user_data for request_id {request_id}")


def write_triton_stats(stats, summary_file):
    with open(summary_file, "w") as summary_f:
        model_stats = stats["model_stats"]
        for model_state in model_stats:
            if "last_inference" not in model_state:
                continue
            summary_f.write(f"model name is {model_state['name']} \n")
            model_inference_stats = model_state["inference_stats"]
            total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
            total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
            total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
            total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
            summary_f.write(
                f"queue time {total_queue_time_s:<5.2f} s, "
                f"compute infer time {total_infer_time_s:<5.2f} s, "
                f"compute input time {total_input_time_s:<5.2f} s, "
                f"compute output time {total_output_time_s:<5.2f} s \n"
            )
            model_batch_stats = model_state["batch_stats"]
            for batch in model_batch_stats:
                batch_size = int(batch["batch_size"])
                compute_input = batch["compute_input"]
                compute_output = batch["compute_output"]
                compute_infer = batch["compute_infer"]
                batch_count = int(compute_infer["count"])
                assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
                compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
                compute_input_time_ms = int(compute_input["ns"]) / 1e6
                compute_output_time_ms = int(compute_output["ns"]) / 1e6
                summary_f.write(
                    f"execute inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
                    f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
                    f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}"
                    f"={compute_infer_time_ms / batch_count:.2f} ms, "
                    f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}"
                    f"={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
                )
                summary_f.write(
                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
                )
                summary_f.write(
                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
                )


def subtract_stats(stats_after, stats_before):
    """Subtracts two Triton inference statistics objects."""
    stats_diff = json.loads(json.dumps(stats_after))
    model_stats_before_map = {
        s["name"]: {
            "version": s["version"],
            "last_inference": s.get("last_inference", 0),
            "inference_count": s.get("inference_count", 0),
            "execution_count": s.get("execution_count", 0),
            "inference_stats": s.get("inference_stats", {}),
            "batch_stats": s.get("batch_stats", []),
        }
        for s in stats_before["model_stats"]
    }

    for model_stat_after in stats_diff["model_stats"]:
        model_name = model_stat_after["name"]
        if model_name in model_stats_before_map:
            model_stat_before = model_stats_before_map[model_name]
            model_stat_after["inference_count"] = str(
                int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
            )
            model_stat_after["execution_count"] = str(
                int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
            )
            if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
                for key in [
                    "success",
                    "fail",
                    "queue",
                    "compute_input",
                    "compute_infer",
                    "compute_output",
                    "cache_hit",
                    "cache_miss",
                ]:
                    if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
                        if "ns" in model_stat_after["inference_stats"][key]:
                            ns_after = int(model_stat_after["inference_stats"][key]["ns"])
                            ns_before = int(model_stat_before["inference_stats"][key]["ns"])
                            model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
                        if "count" in model_stat_after["inference_stats"][key]:
                            count_after = int(model_stat_after["inference_stats"][key]["count"])
                            count_before = int(model_stat_before["inference_stats"][key]["count"])
                            model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
            if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
                batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
                for batch_stat_after in model_stat_after["batch_stats"]:
                    bs = batch_stat_after["batch_size"]
                    if bs in batch_stats_before_map:
                        batch_stat_before = batch_stats_before_map[bs]
                        for key in ["compute_input", "compute_infer", "compute_output"]:
                            if key in batch_stat_after and key in batch_stat_before:
                                count_after = int(batch_stat_after[key]["count"])
                                count_before = int(batch_stat_before[key]["count"])
                                batch_stat_after[key]["count"] = str(count_after - count_before)
                                ns_after = int(batch_stat_after[key]["ns"])
                                ns_before = int(batch_stat_before[key]["ns"])
                                batch_stat_after[key]["ns"] = str(ns_after - ns_before)
    return stats_diff


def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=8001,
        help="gRPC port of the Triton server, default is 8001",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Path to a single reference audio file. It can't be specified together with --manifest-path",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="",
        help="Transcript of the reference audio (used together with --reference-audio)",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="",
        help="Text to synthesize (used together with --reference-audio)",
    )

    parser.add_argument(
        "--huggingface-dataset",
        type=str,
        default="yuekai/seed_tts",
        help="dataset name in huggingface dataset hub",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="dataset split name, default is 'wenetspeech4tts'",
    )

    parser.add_argument(
        "--manifest-path",
        type=str,
        default=None,
        help="Path to a manifest file with one 'utt|prompt_text|prompt_wav|gt_text' entry per line.",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        choices=["f5_tts", "spark_tts", "cosyvoice2", "cosyvoice2_dit"],
        help="triton model_repo module name to request",
    )

    parser.add_argument(
        "--num-tasks",
        type=int,
        default=1,
        help="Number of concurrent tasks for sending",
    )

    parser.add_argument(
        "--log-interval",
        type=int,
        default=5,
        help="Controls how frequently we print the log.",
    )

    parser.add_argument(
        "--compute-wer",
        action="store_true",
        default=False,
        help="True to compute WER.",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        required=False,
        default="./tmp",
        help="log directory",
    )

    parser.add_argument(
        "--mode",
        type=str,
        default="offline",
        choices=["offline", "streaming"],
        help="Select offline or streaming benchmark mode.",
    )

    parser.add_argument(
        "--chunk-overlap-duration",
        type=float,
        default=0.1,
        help="Chunk overlap duration for streaming reconstruction (in seconds).",
    )

    parser.add_argument(
        "--use-spk2info-cache",
        type=str,
        default="False",
        help="Set to 'True' to reuse the server-side spk2info cache for the reference audio.",
    )

    return parser.parse_args()


def load_audio(wav_path, target_sample_rate=16000):
    assert target_sample_rate == 16000, "reference sample rate is hard-coded to 16 kHz on the server side"
    if isinstance(wav_path, dict):
        # HuggingFace datasets yield audio as {"array": ..., "sampling_rate": ...}.
        waveform = wav_path["array"]
        sample_rate = wav_path["sampling_rate"]
    else:
        waveform, sample_rate = sf.read(wav_path)
    if sample_rate != target_sample_rate:
        from scipy.signal import resample

        num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
        waveform = resample(waveform, num_samples)
    return waveform, target_sample_rate


def prepare_request_input_output(
    protocol_client,
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None,
    use_spk2info_cache: bool = False,
):
    """Prepares inputs for Triton inference (offline or streaming)."""
    assert len(waveform.shape) == 1, "waveform should be 1D"
    lengths = np.array([[len(waveform)]], dtype=np.int32)
    if padding_duration:
        # Pad the reference waveform up to a multiple of padding_duration seconds,
        # budgeting extra room for the estimated duration of the target text.
        duration = len(waveform) / sample_rate
        if reference_text:
            estimated_target_duration = duration / len(reference_text) * len(target_text)
        else:
            estimated_target_duration = duration
        required_total_samples = padding_duration * sample_rate * (
            (int(estimated_target_duration + duration) // padding_duration) + 1
        )
        samples = np.zeros((1, required_total_samples), dtype=np.float32)
        samples[0, : len(waveform)] = waveform
    else:
        samples = waveform.reshape(1, -1).astype(np.float32)

    inputs = [
        protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
        protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
        protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
        protocol_client.InferInput("target_text", [1, 1], "BYTES"),
    ]
    inputs[0].set_data_from_numpy(samples)
    inputs[1].set_data_from_numpy(lengths)

    input_data_numpy = np.array([reference_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[2].set_data_from_numpy(input_data_numpy)

    input_data_numpy = np.array([target_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[3].set_data_from_numpy(input_data_numpy)

    outputs = [protocol_client.InferRequestedOutput("waveform")]
    if use_spk2info_cache:
        # With the server-side speaker cache, only the target_text input is sent.
        inputs = inputs[-1:]
    return inputs, outputs


def run_sync_streaming_inference(
    sync_triton_client: tritonclient.grpc.InferenceServerClient,
    model_name: str,
    inputs: list,
    outputs: list,
    request_id: str,
    user_data: UserData,
    chunk_overlap_duration: float,
    save_sample_rate: int,
    audio_save_path: str,
):
    """Helper function to run the blocking sync streaming call."""
    start_time_total = time.time()
    user_data.record_start_time()

    sync_triton_client.async_stream_infer(
        model_name,
        inputs,
        request_id=request_id,
        outputs=outputs,
        enable_empty_final_response=True,
    )

    audios = []
    while True:
        try:
            result = user_data._completed_requests.get(timeout=200)
            if isinstance(result, InferenceServerException):
                print(f"Received InferenceServerException: {result}")
                return None, None, None, None
            response = result.get_response()
            final = response.parameters["triton_final_response"].bool_param
            if final is True:
                break
            audio_chunk = result.as_numpy("waveform").reshape(-1)
            if audio_chunk.size > 0:
                audios.append(audio_chunk)
            else:
                print("Warning: received empty audio chunk.")
        except queue.Empty:
            print(f"Timeout waiting for response for request id {request_id}")
            return None, None, None, None

    end_time_total = time.time()
    total_request_latency = end_time_total - start_time_total
    first_chunk_latency = user_data.get_first_chunk_latency()
    second_chunk_latency = user_data.get_second_chunk_latency()

    if audios:
        if model_name == "spark_tts":
            # Spark-TTS streams overlapping chunks; cross-fade adjacent chunks
            # over chunk_overlap_duration seconds when stitching them together.
            cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)
            reconstructed_audio = None
            if not audios:
                print("Warning: No audio chunks received.")
                reconstructed_audio = np.array([], dtype=np.float32)
            elif len(audios) == 1:
                reconstructed_audio = audios[0]
            else:
                reconstructed_audio = audios[0][:-cross_fade_samples]
                for i in range(1, len(audios)):
                    cross_faded_overlap = (
                        audios[i][:cross_fade_samples] * fade_in
                        + audios[i - 1][-cross_fade_samples:] * fade_out
                    )
                    middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
                    reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
                reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
            if reconstructed_audio is not None and reconstructed_audio.size > 0:
                actual_duration = len(reconstructed_audio) / save_sample_rate
                sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
            else:
                print("Warning: No audio chunks received or reconstructed.")
                actual_duration = 0
        else:
            # Other models stream non-overlapping chunks; simple concatenation suffices.
            reconstructed_audio = np.concatenate(audios)
            actual_duration = len(reconstructed_audio) / save_sample_rate
            sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
    else:
        print("Warning: No audio chunks received.")
        actual_duration = 0

    return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration


async def send_streaming(
    manifest_item_list: list,
    name: str,
    server_url: str,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    chunk_overlap_duration: float = 0.1,
    padding_duration: int = None,
    use_spk2info_cache: bool = False,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])
    sync_triton_client = None
    user_data_map = {}

    try:
        print(f"{name}: Initializing sync client for streaming...")
        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
        sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))

        print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
        for i, item in enumerate(manifest_item_list):
            if i % log_interval == 0:
                print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
            try:
                waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
                reference_text, target_text = item["reference_text"], item["target_text"]

                inputs, outputs = prepare_request_input_output(
                    protocol_client,
                    waveform,
                    reference_text,
                    target_text,
                    sample_rate,
                    padding_duration=padding_duration,
                    use_spk2info_cache=use_spk2info_cache,
                )
                request_id = str(uuid.uuid4())
                user_data = UserData()
                user_data_map[request_id] = user_data
                audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")

                # Run the blocking sync streaming call in a worker thread so that
                # other concurrent tasks can make progress on the event loop.
                (
                    total_request_latency,
                    first_chunk_latency,
                    second_chunk_latency,
                    actual_duration,
                ) = await asyncio.to_thread(
                    run_sync_streaming_inference,
                    sync_triton_client,
                    model_name,
                    inputs,
                    outputs,
                    request_id,
                    user_data,
                    chunk_overlap_duration,
                    save_sample_rate,
                    audio_save_path,
                )

                if total_request_latency is not None:
                    print(
                        f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
                        f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
                        f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
                    )
                    latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
                    total_duration += actual_duration
                else:
                    print(f"{name}: Item {i} failed.")

                del user_data_map[request_id]

            except FileNotFoundError:
                print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
            except Exception as e:
                print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
                import traceback

                traceback.print_exc()
    finally:
        if sync_triton_client:
            try:
                print(f"{name}: Closing stream and sync client...")
                sync_triton_client.stop_stream()
                sync_triton_client.close()
            except Exception as e:
                print(f"{name}: Error closing sync client: {e}")

    print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
    return total_duration, latency_data


async def send(
    manifest_item_list: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    padding_duration: int = None,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    use_spk2info_cache: bool = False,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])

    for i, item in enumerate(manifest_item_list):
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(manifest_item_list)}")
        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
        reference_text, target_text = item["reference_text"], item["target_text"]

        inputs, outputs = prepare_request_input_output(
            protocol_client,
            waveform,
            reference_text,
            target_text,
            sample_rate,
            padding_duration=padding_duration,
            use_spk2info_cache=use_spk2info_cache,
        )
        sequence_id = 100000000 + i + task_id * 10
        start = time.time()
        response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)

        audio = response.as_numpy("waveform").reshape(-1)
        actual_duration = len(audio) / save_sample_rate
        end = time.time() - start

        audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")

        latency_data.append((end, actual_duration))
        total_duration += actual_duration

    return total_duration, latency_data


def load_manifests(manifest_path):
    with open(manifest_path, "r") as f:
        manifest_list = []
        for line in f:
            assert len(line.strip().split("|")) == 4
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
            utt = Path(utt).stem
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
            manifest_list.append(
                {
                    "audio_filepath": prompt_wav,
                    "reference_text": prompt_text,
                    "target_text": gt_text,
                    "target_audio_path": utt,
                }
            )
    return manifest_list
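
# Illustrative manifest line (hypothetical utt id, text, and relative path) in the
# 'utt|prompt_text|prompt_wav|gt_text' format parsed by load_manifests above:
#   utt_0001|reference transcript here|wavs/utt_0001_prompt.wav|target text to synthesize here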


def split_data(data, k):
    n = len(data)
    if n < k:
        print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
        k = n

    quotient = n // k
    remainder = n % k

    result = []
    start = 0
    for i in range(k):
        if i < remainder:
            end = start + quotient + 1
        else:
            end = start + quotient
        result.append(data[start:end])
        start = end

    return result
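
# For example, split_data splits a list into k nearly equal, contiguous parts:
#   split_data(list(range(5)), 2) -> [[0, 1, 2], [3, 4]]
#   split_data(list(range(4)), 2) -> [[0, 1], [2, 3]]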


async def main():
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"

    triton_client = None
    protocol_client = None
    if args.mode == "offline":
        print("Initializing gRPC client for offline mode...")
        triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        protocol_client = grpcclient_aio
    elif args.mode == "streaming":
        print("Initializing gRPC client for streaming mode...")
        # Each streaming task creates its own synchronous client; only the
        # protocol module is needed here for building inputs.
        protocol_client = grpcclient_sync
    else:
        raise ValueError(f"Invalid mode: {args.mode}")

    if args.reference_audio:
        args.num_tasks = 1
        args.log_interval = 1
        manifest_item_list = [
            {
                "reference_text": args.reference_text,
                "target_text": args.target_text,
                "audio_filepath": args.reference_audio,
                "target_audio_path": "test",
            }
        ]
    elif args.huggingface_dataset:
        import datasets

        dataset = datasets.load_dataset(
            args.huggingface_dataset,
            split=args.split_name,
            trust_remote_code=True,
        )
        manifest_item_list = []
        for i in range(len(dataset)):
            manifest_item_list.append(
                {
                    "audio_filepath": dataset[i]["prompt_audio"],
                    "reference_text": dataset[i]["prompt_text"],
                    "target_audio_path": dataset[i]["id"],
                    "target_text": dataset[i]["target_text"],
                }
            )
    else:
        manifest_item_list = load_manifests(args.manifest_path)

    stats_client = None
    stats_before = None
    try:
        print("Initializing temporary async client for fetching stats...")
        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        print("Fetching inference statistics before running tasks...")
        stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
    except Exception as e:
        print(f"Could not retrieve statistics before running tasks: {e}")

    num_tasks = min(args.num_tasks, len(manifest_item_list))
    manifest_item_list = split_data(manifest_item_list, num_tasks)
    os.makedirs(args.log_dir, exist_ok=True)
    args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"

    tasks = []
    start_time = time.time()
    for i in range(num_tasks):
        if args.mode == "offline":
            task = asyncio.create_task(
                send(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    triton_client=triton_client,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=1,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            )
        elif args.mode == "streaming":
            task = asyncio.create_task(
                send_streaming(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    server_url=url,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=10,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    chunk_overlap_duration=args.chunk_overlap_duration,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            )
        tasks.append(task)

    ans_list = await asyncio.gather(*tasks)

    end_time = time.time()
    elapsed = end_time - start_time

    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        if ans:
            total_duration += ans[0]
            latency_data.extend(ans[1])
        else:
            print("Warning: A task returned None, possibly due to an error.")

    if total_duration == 0:
        print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
        rtf = float("inf")
    else:
        rtf = elapsed / total_duration

    s = f"Mode: {args.mode}\n"
    s += f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration / 3600:.2f} hours)\n"
    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

    if latency_data:
        if args.mode == "offline":
            latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
            if latency_list:
                latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
                latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
                s += f"latency_variance: {latency_variance:.2f}\n"
                s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
                s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
                s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
                s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_latency_ms: {latency_ms:.2f}\n"
            else:
                s += "No latency data collected for offline mode.\n"
        elif args.mode == "streaming":
            total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
            first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
            second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]

            s += "\n--- Total Request Latency ---\n"
            if total_latency_list:
                avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
                variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
                s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
                s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
                s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
                s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
                s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
            else:
                s += "No total request latency data collected.\n"

            s += "\n--- First Chunk Latency ---\n"
            if first_chunk_latency_list:
                avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
                variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
                s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
                s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
            else:
                s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"

            s += "\n--- Second Chunk Latency ---\n"
            if second_chunk_latency_list:
                avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
                variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
                s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
                s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
                s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
                s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
                s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
            else:
                s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
    else:
        s += "No latency data collected.\n"

    print(s)

    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    elif args.reference_audio:
        name = Path(args.reference_audio).stem
    else:
        name = "results"
    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
        f.write(s)

    try:
        if stats_client and stats_before:
            print("Fetching inference statistics after running tasks...")
            stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
            print("Calculating statistics difference...")
            stats = subtract_stats(stats_after, stats_before)
            print("Fetching model config...")
            metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
            write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
            with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
                json.dump(metadata, f, indent=4)
        else:
            print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
    except Exception as e:
        print(f"Could not retrieve statistics or config: {e}")
    finally:
        if stats_client:
            try:
                print("Closing temporary async stats client...")
                await stats_client.close()
            except Exception as e:
                print(f"Error closing async stats client: {e}")


if __name__ == "__main__":

    async def run_main():
        try:
            await main()
        except Exception as e:
            print(f"An error occurred in main: {e}")
            import traceback

            traceback.print_exc()

    asyncio.run(run_main())