# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#           2023 Nvidia (authors: Yuekai Zhang)
#           2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a dataset from the HuggingFace hub and sends it to the server
for decoding, in parallel.

Usage:
num_task=2

# For offline F5-TTS
python3 client_grpc.py \
    --server-addr localhost \
    --model-name f5_tts \
    --num-tasks $num_task \
    --huggingface-dataset yuekai/seed_tts \
    --split-name test_zh \
    --log-dir ./log_concurrent_tasks_${num_task}

# For offline Spark-TTS-0.5B
python3 client_grpc.py \
    --server-addr localhost \
    --model-name spark_tts \
    --num-tasks $num_task \
    --huggingface-dataset yuekai/seed_tts \
    --split-name wenetspeech4tts \
    --log-dir ./log_concurrent_tasks_${num_task}
"""

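# The script also supports a streaming benchmark mode (see --mode in get_args()).
# A minimal sketch, reusing the flags from the offline examples above; adjust
# --chunk-overlap-duration to match the server's chunking:
#
#   python3 client_grpc.py \
#       --server-addr localhost \
#       --model-name spark_tts \
#       --num-tasks $num_task \
#       --mode streaming \
#       --chunk-overlap-duration 0.1 \
#       --huggingface-dataset yuekai/seed_tts \
#       --split-name wenetspeech4tts \
#       --log-dir ./log_concurrent_tasks_${num_task}
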
import argparse
import asyncio
import functools
import json
import os
import queue
import time
import types
import uuid
from datetime import datetime
from pathlib import Path

import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc as grpcclient_sync  # sync client, used for streaming mode
import tritonclient.grpc.aio as grpcclient_aio  # async client, used for offline mode and stats
from tritonclient.utils import InferenceServerException, np_to_triton_dtype


# --- UserData and callbacks for streaming mode ---
class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()
        self._first_chunk_time = None
        self._second_chunk_time = None
        self._start_time = None

    def record_start_time(self):
        self._start_time = time.time()

    def get_first_chunk_latency(self):
        if self._first_chunk_time and self._start_time:
            return self._first_chunk_time - self._start_time
        return None

    def get_second_chunk_latency(self):
        if self._first_chunk_time and self._second_chunk_time:
            return self._second_chunk_time - self._first_chunk_time
        return None


def callback(user_data, result, error):
    if not error:
        if user_data._first_chunk_time is None:
            user_data._first_chunk_time = time.time()  # Record time of first successful chunk
        elif user_data._second_chunk_time is None:
            user_data._second_chunk_time = time.time()
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)


def stream_callback(user_data_map, result, error):
    request_id = None
    if error:
        # Note: InferenceServerException does not expose a public request_id() method in all
        # versions of tritonclient, so this part may need adjustment depending on the library
        # version. A more robust approach would wrap the error with its request_id. For now we
        # assume the request_id cannot be recovered from the error, and the corresponding
        # request will simply time out on the client side.
        print(f"An error occurred in the stream callback: {error}")
    else:
        request_id = result.get_response().id
    if request_id:
        user_data = user_data_map.get(request_id)
        if user_data:
            callback(user_data, result, error)
        else:
            print(f"Warning: Could not find user_data for request_id {request_id}")
# --- End UserData and callbacks ---


def write_triton_stats(stats, summary_file):
    with open(summary_file, "w") as summary_f:
        model_stats = stats["model_stats"]
        # Add a note that the log comes from triton_client.get_inference_statistics(),
        # reformatted for better readability.
        summary_f.write(
            "This log is parsed from triton_client.get_inference_statistics() and reformatted for better human readability. \n"
        )
        summary_f.write("To learn more about the log, please refer to: \n")
        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
        summary_f.write(
            "To improve throughput, it is usually desirable to let requests wait in the queue for a while, so that they can be executed with a larger batch size. \n"
        )
        summary_f.write(
            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
        )
        summary_f.write(
            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
        )
        summary_f.write(
            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
        )
        for model_state in model_stats:
            if "last_inference" not in model_state:
                continue
            summary_f.write(f"model name is {model_state['name']} \n")
            model_inference_stats = model_state["inference_stats"]
            total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
            total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
            total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
            total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
            summary_f.write(
                f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n"  # noqa
            )
            model_batch_stats = model_state["batch_stats"]
            for batch in model_batch_stats:
                batch_size = int(batch["batch_size"])
                compute_input = batch["compute_input"]
                compute_output = batch["compute_output"]
                compute_infer = batch["compute_infer"]
                batch_count = int(compute_infer["count"])
                assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
                compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
                compute_input_time_ms = int(compute_input["ns"]) / 1e6
                compute_output_time_ms = int(compute_output["ns"]) / 1e6
                summary_f.write(
                    f"execute inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"  # noqa
                )
                summary_f.write(
                    f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "  # noqa
                )
                summary_f.write(
                    f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"  # noqa
                )

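
# The delayed-batching knobs mentioned in the summary above live in the model's
# config.pbtxt. A minimal sketch of the relevant block (the values here are
# placeholders, not recommendations):
#
#   dynamic_batching {
#     preferred_batch_size: [ 4, 8 ]
#     max_queue_delay_microseconds: 100
#   }

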
def subtract_stats(stats_after, stats_before):
    """Subtracts two Triton inference statistics objects."""
    # Deep copy to avoid modifying the original stats_after
    stats_diff = json.loads(json.dumps(stats_after))
    model_stats_before_map = {
        s["name"]: {
            "version": s["version"],
            "last_inference": s.get("last_inference", 0),
            "inference_count": s.get("inference_count", 0),
            "execution_count": s.get("execution_count", 0),
            "inference_stats": s.get("inference_stats", {}),
            "batch_stats": s.get("batch_stats", []),
        }
        for s in stats_before["model_stats"]
    }

    for model_stat_after in stats_diff["model_stats"]:
        model_name = model_stat_after["name"]
        if model_name in model_stats_before_map:
            model_stat_before = model_stats_before_map[model_name]

            # Subtract counts
            model_stat_after["inference_count"] = str(
                int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
            )
            model_stat_after["execution_count"] = str(
                int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
            )

            # Subtract aggregate stats (like queue, compute times)
            if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
                for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
                    if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
                        if "ns" in model_stat_after["inference_stats"][key]:
                            ns_after = int(model_stat_after["inference_stats"][key]["ns"])
                            ns_before = int(model_stat_before["inference_stats"][key]["ns"])
                            model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
                        if "count" in model_stat_after["inference_stats"][key]:
                            count_after = int(model_stat_after["inference_stats"][key]["count"])
                            count_before = int(model_stat_before["inference_stats"][key]["count"])
                            model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)

            # Subtract batch execution stats
            if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
                batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
                for batch_stat_after in model_stat_after["batch_stats"]:
                    bs = batch_stat_after["batch_size"]
                    if bs in batch_stats_before_map:
                        batch_stat_before = batch_stats_before_map[bs]
                        for key in ["compute_input", "compute_infer", "compute_output"]:
                            if key in batch_stat_after and key in batch_stat_before:
                                count_after = int(batch_stat_after[key]["count"])
                                count_before = int(batch_stat_before[key]["count"])
                                batch_stat_after[key]["count"] = str(count_after - count_before)
                                ns_after = int(batch_stat_after[key]["ns"])
                                ns_before = int(batch_stat_before[key]["ns"])
                                batch_stat_after[key]["ns"] = str(ns_after - ns_before)
    return stats_diff

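
# An abridged sketch of the statistics layout consumed by subtract_stats() and
# write_triton_stats() above, as returned by get_inference_statistics(as_json=True)
# (field values are illustrative only):
#
#   {
#     "model_stats": [
#       {
#         "name": "spark_tts",
#         "inference_count": "12",
#         "execution_count": "6",
#         "inference_stats": {"queue": {"count": "12", "ns": "340000"}, ...},
#         "batch_stats": [{"batch_size": "2", "compute_infer": {"count": "6", "ns": "..."}, ...}]
#       }
#     ]
#   }

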
def get_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--server-addr",
        type=str,
        default="localhost",
        help="Address of the server",
    )

    parser.add_argument(
        "--server-port",
        type=int,
        default=8001,
        help="gRPC port of the Triton server, default is 8001",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default=None,
        help="Path to a single audio file. It can't be specified at the same time as --manifest-path",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="",
        help="Transcript of the reference audio",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="",
        help="Text to synthesize",
    )

    parser.add_argument(
        "--huggingface-dataset",
        type=str,
        default="yuekai/seed_tts",
        help="dataset name on the HuggingFace dataset hub",
    )

    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="dataset split name, default is 'wenetspeech4tts'",
    )

    parser.add_argument(
        "--manifest-path",
        type=str,
        default=None,
        help="Path to a manifest file with one 'utt|prompt_text|prompt_wav|gt_text' entry per line.",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="f5_tts",
        choices=[
            "f5_tts",
            "spark_tts",
            "cosyvoice2",
            "cosyvoice2_dit",
        ],
        help="triton model_repo module name to request",
    )

    parser.add_argument(
        "--num-tasks",
        type=int,
        default=1,
        help="Number of concurrent tasks for sending",
    )

    parser.add_argument(
        "--log-interval",
        type=int,
        default=5,
        help="Controls how frequently we print the log.",
    )

    parser.add_argument(
        "--compute-wer",
        action="store_true",
        default=False,
        help="True to compute WER.",
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        required=False,
        default="./tmp",
        help="log directory",
    )

    # --- Added arguments ---
    parser.add_argument(
        "--mode",
        type=str,
        default="offline",
        choices=["offline", "streaming"],
        help="Select offline or streaming benchmark mode.",
    )

    parser.add_argument(
        "--chunk-overlap-duration",
        type=float,
        default=0.1,
        help="Chunk overlap duration for streaming reconstruction (in seconds).",
    )

    parser.add_argument(
        "--use-spk2info-cache",
        type=str,
        default="False",
        help="Use the spk2info cache for the reference audio.",
    )

    return parser.parse_args()


def load_audio(wav_path, target_sample_rate=16000):
    assert target_sample_rate == 16000, "hard coding in server"
    if isinstance(wav_path, dict):
        waveform = wav_path["array"]
        sample_rate = wav_path["sampling_rate"]
    else:
        waveform, sample_rate = sf.read(wav_path)
    if sample_rate != target_sample_rate:
        from scipy.signal import resample

        num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
        waveform = resample(waveform, num_samples)
    return waveform, target_sample_rate


def prepare_request_input_output(
    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None,  # Optional padding for offline mode
    use_spk2info_cache: bool = False,
):
    """Prepares inputs for Triton inference (offline or streaming)."""
    assert len(waveform.shape) == 1, "waveform should be 1D"
    lengths = np.array([[len(waveform)]], dtype=np.int32)

    # Apply padding only if padding_duration is provided (for offline)
    if padding_duration:
        duration = len(waveform) / sample_rate
        # Estimate the target duration from the text length ratio (crude estimation);
        # avoid division by zero if reference_text is empty.
        if reference_text:
            estimated_target_duration = duration / len(reference_text) * len(target_text)
        else:
            estimated_target_duration = duration  # Assume target duration similar to reference if no text

        # Calculate the required samples based on the estimated total duration
        required_total_samples = padding_duration * sample_rate * (
            (int(estimated_target_duration + duration) // padding_duration) + 1
        )
        samples = np.zeros((1, required_total_samples), dtype=np.float32)
        samples[0, : len(waveform)] = waveform
    else:
        # No padding for streaming or if padding_duration is None
        samples = waveform.reshape(1, -1).astype(np.float32)

    # Common input creation logic
    inputs = [
        protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
        protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
        protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
        protocol_client.InferInput("target_text", [1, 1], "BYTES"),
    ]
    inputs[0].set_data_from_numpy(samples)
    inputs[1].set_data_from_numpy(lengths)

    input_data_numpy = np.array([reference_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[2].set_data_from_numpy(input_data_numpy)

    input_data_numpy = np.array([target_text], dtype=object)
    input_data_numpy = input_data_numpy.reshape((1, 1))
    inputs[3].set_data_from_numpy(input_data_numpy)

    outputs = [protocol_client.InferRequestedOutput("waveform")]

    if use_spk2info_cache:
        # Only send the target_text input; the server reuses cached speaker info.
        inputs = inputs[-1:]
    return inputs, outputs


def run_sync_streaming_inference(
    sync_triton_client: tritonclient.grpc.InferenceServerClient,
    model_name: str,
    inputs: list,
    outputs: list,
    request_id: str,
    user_data: UserData,
    chunk_overlap_duration: float,
    save_sample_rate: int,
    audio_save_path: str,
):
    """Helper function to run the blocking sync streaming call."""
    start_time_total = time.time()
    user_data.record_start_time()  # Record start time for first-chunk latency calculation
    # e.g. 08:47:34.827758
    print(f"Record start time in human readable: {datetime.now()}")

    # Send request
    sync_triton_client.async_stream_infer(
        model_name,
        inputs,
        request_id=request_id,
        outputs=outputs,
        enable_empty_final_response=True,
    )

    # Process results
    audios = []
    while True:
        try:
            result = user_data._completed_requests.get(timeout=20)  # Add timeout
            if isinstance(result, InferenceServerException):
                print(f"Received InferenceServerException: {result}")
                # Don't stop the stream here, just return an error
                return None, None, None, None
            # Get response metadata
            response = result.get_response()
            final = response.parameters["triton_final_response"].bool_param
            if final is True:
                break

            audio_chunk = result.as_numpy("waveform").reshape(-1)
            if audio_chunk.size > 0:  # Only append non-empty chunks
                audios.append(audio_chunk)
            else:
                print("Warning: received empty audio chunk.")
        except queue.Empty:
            print(f"Timeout waiting for response for request id {request_id}")
            # Don't stop the stream here, just return an error
            return None, None, None, None

    end_time_total = time.time()
    total_request_latency = end_time_total - start_time_total
    first_chunk_latency = user_data.get_first_chunk_latency()
    second_chunk_latency = user_data.get_second_chunk_latency()

    # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
    actual_duration = 0
    if audios:
        # Only the spark_tts model uses cross-fade
        if model_name == "spark_tts":
            cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)
            reconstructed_audio = None

            # Simplified reconstruction based on client_grpc_streaming.py
            if not audios:
                print("Warning: No audio chunks received.")
                reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
            elif len(audios) == 1:
                reconstructed_audio = audios[0]
            else:
                reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
                for i in range(1, len(audios)):
                    # Cross-fade section
                    cross_faded_overlap = (
                        audios[i][:cross_fade_samples] * fade_in + audios[i - 1][-cross_fade_samples:] * fade_out
                    )
                    # Middle section of the current chunk
                    middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
                    # Concatenate
                    reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
                # Add the last part of the final chunk
                reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])

            if reconstructed_audio is not None and reconstructed_audio.size > 0:
                actual_duration = len(reconstructed_audio) / save_sample_rate
                # Save reconstructed audio
                sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
            else:
                print("Warning: No audio chunks received or reconstructed.")
                actual_duration = 0  # Set duration to 0 if no audio
        else:
            reconstructed_audio = np.concatenate(audios)
            print(f"reconstructed_audio: {reconstructed_audio.shape}")
            actual_duration = len(reconstructed_audio) / save_sample_rate
            # Save reconstructed audio
            sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
    else:
        print("Warning: No audio chunks received.")
        actual_duration = 0

    return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration


async def send_streaming(
    manifest_item_list: list,
    name: str,
    server_url: str,  # Changed from sync_triton_client
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    chunk_overlap_duration: float = 0.1,
    padding_duration: int = None,
    use_spk2info_cache: bool = False,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])
    sync_triton_client = None  # Initialize client variable
    user_data_map = {}

    try:  # Wrap in try...finally to ensure the client is closed
        print(f"{name}: Initializing sync client for streaming...")
        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here
        sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))

        print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
        for i, item in enumerate(manifest_item_list):
            if i % log_interval == 0:
                print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
            try:
                waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
                reference_text, target_text = item["reference_text"], item["target_text"]

                inputs, outputs = prepare_request_input_output(
                    protocol_client,
                    waveform,
                    reference_text,
                    target_text,
                    sample_rate,
                    padding_duration=padding_duration,
                    use_spk2info_cache=use_spk2info_cache,
                )
                request_id = str(uuid.uuid4())
                user_data = UserData()
                user_data_map[request_id] = user_data
                audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
                print("target_text: ", target_text, "time: ", datetime.now())

                total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
                    run_sync_streaming_inference,
                    sync_triton_client,
                    model_name,
                    inputs,
                    outputs,
                    request_id,
                    user_data,
                    chunk_overlap_duration,
                    save_sample_rate,
                    audio_save_path,
                )

                if total_request_latency is not None:
                    print(
                        f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
                        f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
                        f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
                    )
                    latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
                    total_duration += actual_duration
                else:
                    print(f"{name}: Item {i} failed.")
                del user_data_map[request_id]

            except FileNotFoundError:
                print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
            except Exception as e:
                print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
                import traceback

                traceback.print_exc()

    finally:  # Ensure the client is closed
        if sync_triton_client:
            try:
                print(f"{name}: Closing stream and sync client...")
                sync_triton_client.stop_stream()
                sync_triton_client.close()
            except Exception as e:
                print(f"{name}: Error closing sync client: {e}")

    print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
    return total_duration, latency_data


async def send(
    manifest_item_list: list,
    name: str,
    triton_client: tritonclient.grpc.aio.InferenceServerClient,
    protocol_client: types.ModuleType,
    log_interval: int,
    model_name: str,
    padding_duration: int = None,
    audio_save_dir: str = "./",
    save_sample_rate: int = 16000,
    use_spk2info_cache: bool = False,
):
    total_duration = 0.0
    latency_data = []
    task_id = int(name[5:])
    print(f"manifest_item_list: {manifest_item_list}")
    for i, item in enumerate(manifest_item_list):
        if i % log_interval == 0:
            print(f"{name}: {i}/{len(manifest_item_list)}")

        waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
        reference_text, target_text = item["reference_text"], item["target_text"]

        inputs, outputs = prepare_request_input_output(
            protocol_client,
            waveform,
            reference_text,
            target_text,
            sample_rate,
            padding_duration=padding_duration,
            use_spk2info_cache=use_spk2info_cache,
        )
        sequence_id = 100000000 + i + task_id * 10
        start = time.time()
        response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)

        audio = response.as_numpy("waveform").reshape(-1)
        actual_duration = len(audio) / save_sample_rate
        end = time.time() - start

        audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")

        latency_data.append((end, actual_duration))
        total_duration += actual_duration

    return total_duration, latency_data


def load_manifests(manifest_path):
    """Loads a manifest file with one 'utt|prompt_text|prompt_wav|gt_text' entry per line."""
    manifest_list = []
    with open(manifest_path, "r") as f:
        for line in f:
            assert len(line.strip().split("|")) == 4
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
            utt = Path(utt).stem
            # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
            if not os.path.isabs(prompt_wav):
                prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
            manifest_list.append(
                {
                    "audio_filepath": prompt_wav,
                    "reference_text": prompt_text,
                    "target_text": gt_text,
                    "target_audio_path": utt,
                }
            )
    return manifest_list


def split_data(data, k):
    n = len(data)
    if n < k:
        print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
        k = n

    quotient = n // k
    remainder = n % k

    result = []
    start = 0
    for i in range(k):
        if i < remainder:
            end = start + quotient + 1
        else:
            end = start + quotient

        result.append(data[start:end])
        start = end

    return result
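
# For example (illustrative values): split_data(list(range(5)), 2) spreads the
# remainder over the leading chunks and returns [[0, 1, 2], [3, 4]].

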
async def main():
    args = get_args()
    url = f"{args.server_addr}:{args.server_port}"

    # --- Client Initialization based on mode ---
    triton_client = None
    protocol_client = None
    if args.mode == "offline":
        print("Initializing gRPC client for offline mode...")
        # Use the async client for offline tasks
        triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        protocol_client = grpcclient_aio
    elif args.mode == "streaming":
        print("Initializing gRPC client for streaming mode...")
        # Use the sync client for streaming tasks, handled via asyncio.to_thread.
        # One sync client instance is created PER TASK inside send_streaming.
        protocol_client = grpcclient_sync  # protocol client for input preparation
    else:
        raise ValueError(f"Invalid mode: {args.mode}")
    # --- End Client Initialization ---

    if args.reference_audio:
        args.num_tasks = 1
        args.log_interval = 1
        manifest_item_list = [
            {
                "reference_text": args.reference_text,
                "target_text": args.target_text,
                "audio_filepath": args.reference_audio,
                "target_audio_path": "test",
            }
        ]
    elif args.huggingface_dataset:
        import datasets

        dataset = datasets.load_dataset(
            args.huggingface_dataset,
            split=args.split_name,
            trust_remote_code=True,
        )
        manifest_item_list = []
        tmp_audio_path = "./asset_zero_shot_prompt.wav"
        tmp_audio_text = "希望你以后能够做的比我还好呦。"
        for i in range(len(dataset)):
            manifest_item_list.append(
                {
                    "audio_filepath": dataset[i]["prompt_audio"],
                    "reference_text": dataset[i]["prompt_text"],
                    # "audio_filepath": tmp_audio_path,
                    # "reference_text": tmp_audio_text,
                    "target_audio_path": dataset[i]["id"],
                    "target_text": dataset[i]["target_text"],
                }
            )
        # manifest_item_list = manifest_item_list[:4]
    else:
        manifest_item_list = load_manifests(args.manifest_path)

    # --- Statistics Fetching (Before) ---
    stats_client = None
    stats_before = None
    try:
        print("Initializing temporary async client for fetching stats...")
        stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
        print("Fetching inference statistics before running tasks...")
        stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
    except Exception as e:
        print(f"Could not retrieve statistics before running tasks: {e}")
    # --- End Statistics Fetching (Before) ---

    num_tasks = min(args.num_tasks, len(manifest_item_list))
    manifest_item_list = split_data(manifest_item_list, num_tasks)

    os.makedirs(args.log_dir, exist_ok=True)
    args.use_spk2info_cache = args.use_spk2info_cache in ("True", "true")
    tasks = []
    start_time = time.time()
    for i in range(num_tasks):
        # --- Task Creation based on mode ---
        if args.mode == "offline":
            task = asyncio.create_task(
                send(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    triton_client=triton_client,
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=1,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            )
        elif args.mode == "streaming":
            task = asyncio.create_task(
                send_streaming(
                    manifest_item_list[i],
                    name=f"task-{i}",
                    server_url=url,  # Pass the URL instead of a client
                    protocol_client=protocol_client,
                    log_interval=args.log_interval,
                    model_name=args.model_name,
                    audio_save_dir=args.log_dir,
                    padding_duration=10,
                    save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
                    chunk_overlap_duration=args.chunk_overlap_duration,
                    use_spk2info_cache=args.use_spk2info_cache,
                )
            )
        # --- End Task Creation ---
        tasks.append(task)

    ans_list = await asyncio.gather(*tasks)

    end_time = time.time()
    elapsed = end_time - start_time

    total_duration = 0.0
    latency_data = []
    for ans in ans_list:
        if ans:
            total_duration += ans[0]
            latency_data.extend(ans[1])  # Use extend for a list of lists
        else:
            print("Warning: A task returned None, possibly due to an error.")

    if total_duration == 0:
        print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
        rtf = float("inf")
    else:
        rtf = elapsed / total_duration

    s = f"Mode: {args.mode}\n"
    s += f"RTF: {rtf:.4f}\n"
    s += f"total_duration: {total_duration:.3f} seconds\n"
    s += f"({total_duration / 3600:.2f} hours)\n"
    s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"

    # --- Statistics Reporting based on mode ---
    if latency_data:
        if args.mode == "offline":
            # Original offline latency calculation
            latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
            if latency_list:
                latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
                latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
                s += f"latency_variance: {latency_variance:.2f}\n"
                s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
                s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
                s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
                s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_latency_ms: {latency_ms:.2f}\n"
            else:
                s += "No latency data collected for offline mode.\n"
        elif args.mode == "streaming":
            # Calculate stats for total request latency and first/second chunk latency
            total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
            first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
            second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]

            s += "\n--- Total Request Latency ---\n"
            if total_latency_list:
                avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
                variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
                s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
                s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
                s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
                s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
                s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
            else:
                s += "No total request latency data collected.\n"

            s += "\n--- First Chunk Latency ---\n"
            if first_chunk_latency_list:
                avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
                variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
                s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
                s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
                s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
            else:
                s += "No first chunk latency data collected (check for errors or if all requests failed before the first chunk).\n"

            s += "\n--- Second Chunk Latency ---\n"
            if second_chunk_latency_list:
                avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
                variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
                s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
                s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
                s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
                s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
                s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
                s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
            else:
                s += "No second chunk latency data collected (check for errors or if all requests failed before the second chunk).\n"
    else:
        s += "No latency data collected.\n"
    # --- End Statistics Reporting ---

    print(s)

    if args.manifest_path:
        name = Path(args.manifest_path).stem
    elif args.split_name:
        name = args.split_name
    elif args.reference_audio:
        name = Path(args.reference_audio).stem
    else:
        name = "results"  # Default name if no manifest/split/audio provided
    with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
        f.write(s)

    # --- Statistics Fetching using the temporary Async Client ---
    # Use a separate async client for fetching stats regardless of mode.
    try:
        if stats_client and stats_before:
            print("Fetching inference statistics after running tasks...")
            stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
            print("Calculating statistics difference...")
            stats = subtract_stats(stats_after, stats_before)
            print("Fetching model config...")
            metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)

            write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
            with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
                json.dump(metadata, f, indent=4)
        else:
            print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
    except Exception as e:
        print(f"Could not retrieve statistics or config: {e}")
    finally:
        if stats_client:
            try:
                print("Closing temporary async stats client...")
                await stats_client.close()
            except Exception as e:
                print(f"Error closing async stats client: {e}")
    # --- End Statistics Fetching ---


if __name__ == "__main__":
    # asyncio.run(main())  # Use TaskGroup for better exception handling if needed

    async def run_main():
        try:
            await main()
        except Exception as e:
            print(f"An error occurred in main: {e}")
            import traceback

            traceback.print_exc()

    asyncio.run(run_main())