# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 offline_inference.py \
    --output-dir $output_dir \
    --llm-model-name-or-path $huggingface_model_local_dir \
    --token2wav-path $model_scope_model_local_dir \
    --backend $backend \
    --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
    --engine-dir $trt_engines_dir \
    --split-name ${dataset} || exit 1
"""
import argparse
import json
import os
import sys
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from cosyvoice.utils.file_utils import load_wav
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import soundfile as sf
import s3tokenizer
from functools import partial
import time
import requests
import asyncio
import httpx

sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")

try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass


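# The two async helpers below are used by the "trtllm-serve" backend when
# --batch-size > 1: each chat in a batch is posted concurrently to an
# OpenAI-compatible /v1/chat/completions endpoint and the generated
# speech-token strings are gathered in request order.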
async def send_request_async(client, url, payload):
    response = await client.post(url, json=payload, timeout=None)
    response.raise_for_status()
    response_json = response.json()
    return response_json['choices'][0]['message']['content']


async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
    async with httpx.AsyncClient() as client:
        tasks = []
        for chat in chats:
            payload = {
                "model": model_name,
                "messages": chat,
                "max_tokens": 2048,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": 1.1,
                "stop": ["<|eos1|>", "<|eos|>"],
                "stream": False,
            }
            tasks.append(send_request_async(client, api_base, payload))
        return await asyncio.gather(*tasks)


def extract_speech_ids(speech_tokens_str):
    """Extract speech IDs from token strings like <|s_23456|>"""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
    """Convert CosyVoice2 tokens to a speech ID string like <|s_23456|>"""
    speech_id_str = ""
    for token in cosy2_tokens:
        speech_id_str += f"<|s_{token}|>"
    return speech_id_str


def get_args():
    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="dir to save result"
    )
    parser.add_argument(
        "--batch-size",
        default=1,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--token2wav-batch-size",
        default=1,
        type=int,
        help="batch size (per-device) for token2wav (speech-token-to-waveform) inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=0, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=None, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--llm-model-name-or-path",
        required=True,
        type=str,
        help="LLM model path (includes both model and tokenizer)",
    )
    parser.add_argument(
        "--token2wav-path",
        required=True,
        type=str,
        help="CosyVoice2 token2wav model path",
    )
    parser.add_argument(
        "--prompt-text",
        type=str,
        default=None,
        help="The prompt text for CosyVoice2",
    )
    parser.add_argument(
        "--prompt-speech-path",
        type=str,
        default=None,
        help="The path to the prompt speech for CosyVoice2",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="top p for sampling",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="temperature for sampling",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=50,
        help="top k for sampling",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="hf",
        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
        help="Backend for LLM inference: 'hf' for HuggingFace Transformers, 'trtllm' for TensorRT-LLM, "
        "'vllm' for vLLM, 'trtllm-serve' for an OpenAI-compatible TensorRT-LLM server",
    )
    parser.add_argument(
        "--engine-dir",
        type=str,
        default=None,
        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
    )
    parser.add_argument(
        "--kv-cache-free-gpu-memory-fraction",
        type=float,
        default=0.6,
        help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
    )
    parser.add_argument(
        "--openai-api-base",
        type=str,
        default="http://localhost:8000/v1/chat/completions",
        help="OpenAI API base URL (for trtllm-serve backend)",
    )
    parser.add_argument(
        "--openai-model-name",
        type=str,
        default="trt_engines_bfloat16",
        help="Model name to use with OpenAI API (for trtllm-serve backend)",
    )
    args = parser.parse_args()
    return args


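# For every dataset row the collator prepares: (1) the prompt audio resampled to
# 16 kHz, (2) its CosyVoice2 speech tokens (taken from the dataset if present,
# otherwise computed with the S3 tokenizer), and (3) the LLM prompt, a chat whose
# user turn is prompt_text + target_text and whose assistant turn is pre-filled
# with the prompt's speech tokens so that generation continues from them
# (continue_final_message=True).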
def data_collator(batch, tokenizer, s3_tokenizer):
    """Simplified data collator for batch_size=1 processing"""
    collator_start_time = time.time()
    total_audio_processing_time = 0
    total_speech_tokenization_time = 0
    total_text_tokenization_time = 0
    target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
    prompt_text_after_apply_template_list = []
    mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
    chat_list = []
    for _, item in enumerate(batch):
        audio_processing_start_time = time.time()
        prompt_text, target_text = (
            item["prompt_text"],
            item["target_text"],
        )
        prompt_text_list.append(prompt_text)
        full_text = prompt_text + target_text
        # remove unnecessary punctuation for the cosyvoice3 zero_shot_zh dataset
        puncts = ['"', '(', ')', '“', '”', '‘', '（', '）', '\'']
        for p in puncts:
            if p in full_text:
                full_text = full_text.replace(p, '')
                print(f"removed {p} from {full_text}")
        # store the cleaned text; it is used below to build the chat prompt
        full_text_list.append(full_text)
        # get prompt audio for CosyVoice2 (convert to 16kHz)
        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
        print(ref_audio_org.shape)
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org
        prompt_audio_list.append(ref_audio)
        audio_processing_end_time = time.time()
        total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
        speech_tokenization_start_time = time.time()
        if "prompt_audio_cosy2_tokens" in item:
            prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
            prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
        else:
            mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
    if len(mels) > 0:
        mels, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
        for i in range(len(codes)):
            prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
    speech_tokenization_end_time = time.time()
    total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
        text_tokenization_start_time = time.time()
        prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
        # Create chat template for LLM generation
        chat = [
            {"role": "user", "content": full_text_list[i]},
            {"role": "assistant", "content": prompt_audio_cosy2_id_str}
        ]
        chat_list.append(chat)
        assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True
        )
        input_ids_list.append(input_ids.squeeze(0))
        prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
        prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
        text_tokenization_end_time = time.time()
        total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
    ids = [item["id"] for item in batch]
    return {
        "input_ids": input_ids_list,
        "ids": ids,
        "prompt_text": prompt_text_list,
        "prompt_audio_list": prompt_audio_list,
        "prompt_text_after_apply_template": prompt_text_after_apply_template_list,
        "audio_processing_time": total_audio_processing_time,
        "speech_tokenization_time": total_speech_tokenization_time,
        "text_tokenization_time": total_text_tokenization_time,
        "chat_list": chat_list
    }


def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank


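# Pipeline: the selected LLM backend (HF / TensorRT-LLM / vLLM / trtllm-serve)
# generates CosyVoice2 speech tokens, token2wav converts them to 24 kHz waveforms,
# and per-stage timings are written to <output_dir>/log.txt. init_distributed()
# reads WORLD_SIZE/LOCAL_RANK/RANK for multi-GPU launches but is not called here;
# main() runs on a single device.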
def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    assert torch.cuda.is_available()
    local_rank, world_size, rank = 0, 1, 0
    device = torch.device(f"cuda:{local_rank}")

    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
    if args.backend == "hf":
        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
        model.eval()
        model.to(device)
        runner = None
    elif args.backend == "trtllm":
        if args.engine_dir is None:
            raise ValueError("--engine-dir is required when backend is 'trtllm'")
        runtime_rank = tensorrt_llm.mpi_rank()
        model = None
        runner_kwargs = dict(
            engine_dir=args.engine_dir,
            rank=runtime_rank,
            max_output_len=2048,
            enable_context_fmha_fp32_acc=False,
            max_batch_size=args.batch_size,
            max_input_len=512,
            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
            cuda_graph_mode=False,
            gather_generation_logits=False,
        )
        runner = ModelRunnerCpp.from_dir(**runner_kwargs)
    elif args.backend == "vllm":
        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
        runner = None
    elif args.backend == "trtllm-serve":
        model = None
        runner = None
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")

    if 'Step-Audio-2-mini' in args.token2wav_path:
        from token2wav_dit import CosyVoice2_Token2Wav
    else:
        assert 'CosyVoice2-0.5B' in args.token2wav_path
        from token2wav import CosyVoice2_Token2Wav
    token2wav_model = CosyVoice2_Token2Wav(
        model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
    )

    if args.prompt_speech_path:
        prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
    else:
        prompt_speech_16k = None

    s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
    dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
    dataset = load_dataset(
        dataset_name,
        split=args.split_name,
        trust_remote_code=True,
    )

    sampler = None
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
    )
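
    # The full pipeline is run three times; each pass overwrites log.txt, so the
    # reported metrics come from the final pass (the earlier passes effectively
    # act as warm-up).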
    for _ in range(3):
        print(f"Running {_} times")
        total_llm_time = 0
        total_token2wav_time = 0
        total_data_load_time = 0
        total_llm_post_processing_time = 0
        total_audio_save_time = 0
        total_audio_processing_time_in_collator = 0
        total_speech_tokenization_time_in_collator = 0
        total_text_tokenization_time_in_collator = 0
        total_audio_samples = 0
        start_time = time.time()
        total_steps = len(dataset)
        if rank == 0:
            progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
        last_batch_end_time = time.time()
        for batch in dataloader:
            data_loaded_time = time.time()
            total_data_load_time += data_loaded_time - last_batch_end_time
            total_audio_processing_time_in_collator += batch["audio_processing_time"]
            total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
            total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
            with torch.no_grad():
                llm_start_time = time.time()
                if args.backend == "hf":
                    input_ids_list = batch["input_ids"]
                    if len(input_ids_list) == 1:
                        input_ids = input_ids_list[0].unsqueeze(0)
                        attention_mask = torch.ones_like(input_ids)
                    else:
                        max_len = max([len(input_ids) for input_ids in input_ids_list])
                        input_ids_list_new = [
                            torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
                            for input_ids in input_ids_list
                        ]
                        input_ids = torch.stack(input_ids_list_new)
                        attention_mask = torch.zeros_like(input_ids)
                        for i in range(len(input_ids_list)):
                            attention_mask[i, :len(input_ids_list[i])] = 1
                    input_ids = input_ids.to(device)
                    outputs = model.generate(
                        input_ids=input_ids.to(device),
                        attention_mask=attention_mask.to(device),
                        max_new_tokens=2048,
                        do_sample=True,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        repetition_penalty=1.1,
                        top_k=args.top_k,
                    )
                    torch.cuda.synchronize()
                elif args.backend == "trtllm":
                    batch_input_ids = list(batch["input_ids"])
                    input_lengths = [x.size(0) for x in batch_input_ids]
                    end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
                    print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
                    outputs = runner.generate(
                        batch_input_ids=batch_input_ids,
                        max_new_tokens=2048,
                        end_id=end_id,
                        pad_id=end_id,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        repetition_penalty=1.1,
                        num_return_sequences=1,
                        streaming=False,
                        output_sequence_lengths=True,
                        output_generation_logits=False,
                        return_dict=True,
                        return_all_generated_tokens=False,
                    )
                    torch.cuda.synchronize()
                    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
                    num_output_sents, num_beams, _ = output_ids.size()
                    assert num_beams == 1
                    beam = 0
                    batch_size = len(batch["input_ids"])
                    num_return_sequences = num_output_sents // batch_size
                    assert num_return_sequences == 1
                    outputs = []
                    for i in range(batch_size * num_return_sequences):
                        batch_idx = i // num_return_sequences
                        seq_idx = i % num_return_sequences
                        output_begin = input_lengths[batch_idx]
                        output_end = sequence_lengths[i][beam]
                        outputs_i = output_ids[i][beam][:output_end].tolist()
                        outputs.append(outputs_i)
                elif args.backend == "vllm":
                    input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
                    sampling_params = SamplingParams(
                        temperature=args.temperature,
                        top_p=args.top_p,
                        top_k=args.top_k,
                        repetition_penalty=1.1,
                        max_tokens=2048,
                    )
                    outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
                    print(outputs)
                    for j, output in enumerate(outputs):
                        # prepend the prompt so downstream slicing by input length works;
                        # token_ids may be a tuple depending on the vLLM version
                        outputs[j] = input_ids_list[j] + list(output.outputs[0].token_ids)
                elif args.backend == "trtllm-serve":
                    if args.batch_size > 1:
                        outputs = asyncio.run(send_batch_requests_async(
                            args.openai_api_base,
                            args.openai_model_name,
                            batch["chat_list"],
                            args.temperature,
                            args.top_p,
                            args.top_k,
                        ))
                    else:
                        outputs = []
                        for i, chat in enumerate(batch["chat_list"]):
                            payload = {
                                "model": args.openai_model_name,
                                "messages": chat,
                                "max_tokens": 2048,
                                "temperature": args.temperature,
                                "top_p": args.top_p,
                                "top_k": args.top_k,
                                "repetition_penalty": 1.1,
                                "stop": ["<|eos1|>", "<|eos|>"],
                                "stream": False,
                            }
                            response = requests.post(args.openai_api_base, json=payload)
                            response.raise_for_status()
                            response_json = response.json()
                            generated_content = response_json['choices'][0]['message']['content']
                            outputs.append(generated_content)
                llm_end_time = time.time()
                total_llm_time += (llm_end_time - llm_start_time)
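
                # Convert LLM output into integer speech IDs. The "trtllm-serve"
                # backend already returns decoded text, which is split back into
                # <|s_N|> tokens; the other backends return token IDs, so the
                # prompt is stripped by its length and the remainder is decoded
                # with the tokenizer before extraction.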
                items_for_token_2wav = []
                for i in range(len(batch["ids"])):
                    llm_post_processing_start_time = time.time()
                    if args.backend == "trtllm-serve":
                        speech_tokens_str = outputs[i].strip().split('><')
                        if len(speech_tokens_str) > 1:
                            speech_tokens_str = [
                                t if t.startswith('<') else '<' + t for t in speech_tokens_str
                            ]
                            speech_tokens_str = [
                                t if t.endswith('>') else t + '>' for t in speech_tokens_str
                            ]
                        speech_ids = extract_speech_ids(speech_tokens_str)
                    else:
                        input_length = len(batch["input_ids"][i])
                        generated_ids = outputs[i][input_length:]
                        speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                        speech_ids = extract_speech_ids(speech_tokens_str)
                    print(i, speech_ids)
                    if len(speech_ids) == 0:
                        print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                        continue
                    if args.prompt_text is not None:
                        current_prompt_text = args.prompt_text
                        current_prompt_audio = prompt_speech_16k
                    else:
                        current_prompt_text = batch["prompt_text"][i]
                        current_prompt_audio = batch["prompt_audio_list"][i]
                    llm_post_processing_end_time = time.time()
                    total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
                    if current_prompt_audio is not None:
                        items_for_token_2wav.append({
                            "speech_ids": speech_ids,
                            "prompt_audio": current_prompt_audio.squeeze(0),
                            "id": batch["ids"][i]
                        })
                    else:
                        print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
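
                # Run token2wav on chunks of --token2wav-batch-size items and save
                # each utterance as a 24 kHz wav named <id>.wav in the output dir.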
                for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
                    t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
                    if not t2w_batch:
                        continue
                    t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
                    t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
                    t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
                    t2w_ids = [item["id"] for item in t2w_batch]
                    token2wav_start_time = time.time()
                    generated_wavs = token2wav_model(
                        t2w_generated_speech_tokens_list,
                        t2w_prompt_audios_list,
                        t2w_prompt_audios_sample_rate,
                    )
                    token2wav_end_time = time.time()
                    total_token2wav_time += (token2wav_end_time - token2wav_start_time)
                    audio_save_start_time = time.time()
                    for j, audio_hat in enumerate(generated_wavs):
                        generated_wave = audio_hat.squeeze().cpu().numpy()
                        total_audio_samples += len(generated_wave)
                        target_sample_rate = 24000
                        utt = t2w_ids[j]
                        sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
                        print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
                    audio_save_end_time = time.time()
                    total_audio_save_time += audio_save_end_time - audio_save_start_time

            if rank == 0:
                progress_bar.update(world_size * len(batch["ids"]))
            last_batch_end_time = time.time()
        if rank == 0:
            progress_bar.close()

        end_time = time.time()
        target_sample_rate = 24000
        total_audio_duration_seconds = total_audio_samples / target_sample_rate
        log_file_path = os.path.join(args.output_dir, "log.txt")
        with open(log_file_path, 'w') as f:
            args_dict = vars(args)
            log_data = {
                "args": args_dict,
                "data_load_time_seconds": total_data_load_time,
                "audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
                "speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
                "text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
                "llm_time_seconds": total_llm_time,
                "llm_post_processing_time_seconds": total_llm_post_processing_time,
                "token2wav_time_seconds": total_token2wav_time,
                "audio_save_time_seconds": total_audio_save_time,
                "total_audio_duration_seconds": total_audio_duration_seconds,
                "pipeline_time_seconds": end_time - start_time,
            }
            print(log_data)
            f.write(json.dumps(log_data, indent=4))
        print(f"Metrics logged to {log_file_path}")


if __name__ == "__main__":
    args = get_args()
    if args.backend == "vllm":
        from vllm import LLM, SamplingParams
    elif args.backend == "trtllm":
        import tensorrt_llm
        from tensorrt_llm.runtime import ModelRunnerCpp
    elif args.backend == "hf":
        from transformers import AutoModelForCausalLM
    elif args.backend == "trtllm-serve":
        pass
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
    main(args)