offline_inference.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652
  1. # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
  2. # SPDX-License-Identifier: Apache-2.0
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """ Example Usage
  16. CUDA_VISIBLE_DEVICES=0 \
  17. python3 offline_inference.py \
  18. --output-dir $output_dir \
  19. --llm-model-name-or-path $huggingface_model_local_dir \
  20. --token2wav-path $model_scope_model_local_dir \
  21. --backend $backend \
  22. --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
  23. --engine-dir $trt_engines_dir \
  24. --split-name ${dataset} || exit 1
  25. """
  26. import argparse
  27. import json
  28. import os
  29. import sys
  30. from pathlib import Path
  31. import torch
  32. import torch.distributed as dist
  33. import torch.nn.functional as F
  34. import torchaudio
  35. from cosyvoice.utils.file_utils import load_wav
  36. from datasets import load_dataset
  37. from transformers import AutoTokenizer
  38. from torch.utils.data import DataLoader, Dataset
  39. from tqdm import tqdm
  40. import soundfile as sf
  41. import s3tokenizer
  42. from functools import partial
  43. import time
  44. import requests
  45. import asyncio
  46. import httpx
  47. from token2wav import CosyVoice2_Token2Wav
# Add the Matcha-TTS checkout under /workspace/CosyVoice/third_party to the
# module search path.
# NOTE(review): this runs after the imports above, so it can only affect
# imports that are resolved later at runtime — confirm that is intended.
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
# Ask torch.multiprocessing to use the "spawn" start method. A RuntimeError
# (presumably raised when a start method has already been fixed — TODO
# confirm) is deliberately ignored so the script continues with whatever
# start method is already in effect.
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass
  53. async def send_request_async(client, url, payload):
  54. response = await client.post(url, json=payload, timeout=None)
  55. response.raise_for_status()
  56. response_json = response.json()
  57. return response_json['choices'][0]['message']['content']
  58. async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
  59. async with httpx.AsyncClient() as client:
  60. tasks = []
  61. for chat in chats:
  62. payload = {
  63. "model": model_name,
  64. "messages": chat,
  65. "max_tokens": 2048,
  66. "temperature": temperature,
  67. "top_p": top_p,
  68. "top_k": top_k,
  69. "repetition_penalty": 1.1,
  70. "stop": ["<|eos1|>", "<|eos|>"],
  71. "stream": False,
  72. }
  73. tasks.append(send_request_async(client, api_base, payload))
  74. return await asyncio.gather(*tasks)
  75. def extract_speech_ids(speech_tokens_str):
  76. """Extract speech IDs from token strings like <|s_23456|>"""
  77. speech_ids = []
  78. for token_str in speech_tokens_str:
  79. if token_str.startswith('<|s_') and token_str.endswith('|>'):
  80. num_str = token_str[4:-2]
  81. num = int(num_str)
  82. speech_ids.append(num)
  83. else:
  84. print(f"Unexpected token: {token_str}")
  85. return speech_ids
  86. def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
  87. """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
  88. speech_id_str = ""
  89. for token in cosy2_tokens:
  90. speech_id_str += f"<|s_{token}|>"
  91. return speech_id_str
  92. def get_args():
  93. parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
  94. parser.add_argument(
  95. "--split-name",
  96. type=str,
  97. default="wenetspeech4tts",
  98. help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
  99. )
  100. parser.add_argument(
  101. "--output-dir", required=True, type=str, help="dir to save result"
  102. )
  103. parser.add_argument(
  104. "--batch-size",
  105. default=1,
  106. type=int,
  107. help="batch size (per-device) for inference",
  108. )
  109. parser.add_argument(
  110. "--token2wav-batch-size",
  111. default=1,
  112. type=int,
  113. help="batch size (per-device) for inference",
  114. )
  115. parser.add_argument(
  116. "--num-workers", type=int, default=0, help="workers for dataloader"
  117. )
  118. parser.add_argument(
  119. "--prefetch", type=int, default=None, help="prefetch for dataloader"
  120. )
  121. parser.add_argument(
  122. "--llm-model-name-or-path",
  123. required=True,
  124. type=str,
  125. help="LLM model path (includes both model and tokenizer)",
  126. )
  127. parser.add_argument(
  128. "--token2wav-path",
  129. required=True,
  130. type=str,
  131. help="CosyVoice2 token2wav model path",
  132. )
  133. parser.add_argument(
  134. "--prompt-text",
  135. type=str,
  136. default=None,
  137. help="The prompt text for CosyVoice2",
  138. )
  139. parser.add_argument(
  140. "--prompt-speech-path",
  141. type=str,
  142. default=None,
  143. help="The path to the prompt speech for CosyVoice2",
  144. )
  145. parser.add_argument(
  146. "--top-p",
  147. type=float,
  148. default=0.95,
  149. help="top p for sampling",
  150. )
  151. parser.add_argument(
  152. "--temperature",
  153. type=float,
  154. default=0.8,
  155. help="temperature for sampling",
  156. )
  157. parser.add_argument(
  158. "--top-k",
  159. type=int,
  160. default=50,
  161. help="top k for sampling",
  162. )
  163. parser.add_argument(
  164. "--backend",
  165. type=str,
  166. default="hf",
  167. choices=["hf", "trtllm", "vllm", "trtllm-serve"],
  168. help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
  169. )
  170. parser.add_argument(
  171. "--engine-dir",
  172. type=str,
  173. default=None,
  174. help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
  175. )
  176. parser.add_argument(
  177. "--kv-cache-free-gpu-memory-fraction",
  178. type=float,
  179. default=0.6,
  180. help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
  181. )
  182. parser.add_argument(
  183. "--openai-api-base",
  184. type=str,
  185. default="http://localhost:8000/v1/chat/completions",
  186. help="OpenAI API base URL (for trtllm-serve backend)",
  187. )
  188. parser.add_argument(
  189. "--openai-model-name",
  190. type=str,
  191. default="trt_engines_bfloat16",
  192. help="Model name to use with OpenAI API (for trtllm-serve backend)",
  193. )
  194. args = parser.parse_args()
  195. return args
  196. def data_collator(batch, tokenizer, s3_tokenizer):
  197. """Simplified data collator for batch_size=1 processing"""
  198. collator_start_time = time.time()
  199. total_audio_processing_time = 0
  200. total_speech_tokenization_time = 0
  201. total_text_tokenization_time = 0
  202. target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
  203. device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
  204. input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
  205. prompt_text_after_apply_template_list = []
  206. mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
  207. chat_list = []
  208. for _, item in enumerate(batch):
  209. audio_processing_start_time = time.time()
  210. prompt_text, target_text = (
  211. item["prompt_text"],
  212. item["target_text"],
  213. )
  214. prompt_text_list.append(prompt_text)
  215. full_text = prompt_text + target_text
  216. full_text_list.append(full_text)
  217. # remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
  218. puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\'']
  219. for p in puncts:
  220. if p in full_text:
  221. full_text = full_text.replace(p, '')
  222. print(f"removed {p} from {full_text}")
  223. # get prompt audio for CosyVoice2 (convert to 16kHz)
  224. ref_audio_org, ref_sr = (
  225. item["prompt_audio"]["array"],
  226. item["prompt_audio"]["sampling_rate"],
  227. )
  228. ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
  229. print(ref_audio_org.shape)
  230. if ref_sr != target_sample_rate:
  231. resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
  232. ref_audio = resampler(ref_audio_org)
  233. else:
  234. ref_audio = ref_audio_org
  235. prompt_audio_list.append(ref_audio)
  236. audio_processing_end_time = time.time()
  237. total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
  238. speech_tokenization_start_time = time.time()
  239. if "prompt_audio_cosy2_tokens" in item:
  240. prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
  241. prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
  242. else:
  243. mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
  244. if len(mels) > 0:
  245. mels, mels_lens = s3tokenizer.padding(mels)
  246. codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
  247. for i in range(len(codes)):
  248. prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
  249. speech_tokenization_end_time = time.time()
  250. total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
  251. for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
  252. text_tokenization_start_time = time.time()
  253. prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
  254. # Create chat template for LLM generation
  255. chat = [
  256. {"role": "user", "content": full_text_list[i]},
  257. {"role": "assistant", "content": prompt_audio_cosy2_id_str}
  258. ]
  259. chat_list.append(chat)
  260. assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
  261. input_ids = tokenizer.apply_chat_template(
  262. chat,
  263. tokenize=True,
  264. return_tensors='pt',
  265. continue_final_message=True
  266. )
  267. input_ids_list.append(input_ids.squeeze(0))
  268. prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
  269. prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
  270. text_tokenization_end_time = time.time()
  271. total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
  272. ids = [item["id"] for item in batch]
  273. return {
  274. "input_ids": input_ids_list,
  275. "ids": ids,
  276. "prompt_text": prompt_text_list,
  277. "prompt_audio_list": prompt_audio_list,
  278. "prompt_text_after_apply_template": prompt_text_after_apply_template_list,
  279. "audio_processing_time": total_audio_processing_time,
  280. "speech_tokenization_time": total_speech_tokenization_time,
  281. "text_tokenization_time": total_text_tokenization_time,
  282. "chat_list": chat_list
  283. }
  284. def init_distributed():
  285. world_size = int(os.environ.get("WORLD_SIZE", 1))
  286. local_rank = int(os.environ.get("LOCAL_RANK", 0))
  287. rank = int(os.environ.get("RANK", 0))
  288. print(
  289. "Inference on multiple gpus, this gpu {}".format(local_rank)
  290. + ", rank {}, world_size {}".format(rank, world_size)
  291. )
  292. torch.cuda.set_device(local_rank)
  293. dist.init_process_group("nccl")
  294. return world_size, local_rank, rank
  295. def main(args):
  296. os.makedirs(args.output_dir, exist_ok=True)
  297. assert torch.cuda.is_available()
  298. local_rank, world_size, rank = 0, 1, 0
  299. device = torch.device(f"cuda:{local_rank}")
  300. tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
  301. if args.backend == "hf":
  302. model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
  303. model.eval()
  304. model.to(device)
  305. runner = None
  306. elif args.backend == "trtllm":
  307. if args.engine_dir is None:
  308. raise ValueError("--engine-dir is required when backend is 'trtllm'")
  309. runtime_rank = tensorrt_llm.mpi_rank()
  310. model = None
  311. runner_kwargs = dict(
  312. engine_dir=args.engine_dir,
  313. rank=runtime_rank,
  314. max_output_len=2048,
  315. enable_context_fmha_fp32_acc=False,
  316. max_batch_size=args.batch_size,
  317. max_input_len=512,
  318. kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
  319. cuda_graph_mode=False,
  320. gather_generation_logits=False,
  321. )
  322. runner = ModelRunnerCpp.from_dir(**runner_kwargs)
  323. elif args.backend == "vllm":
  324. model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
  325. runner = None
  326. elif args.backend == "trtllm-serve":
  327. model = None
  328. runner = None
  329. else:
  330. raise ValueError(f"Unsupported backend: {args.backend}")
  331. token2wav_model = CosyVoice2_Token2Wav(
  332. model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
  333. )
  334. if args.prompt_speech_path:
  335. prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
  336. else:
  337. prompt_speech_16k = None
  338. s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
  339. dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
  340. dataset = load_dataset(
  341. dataset_name,
  342. split=args.split_name,
  343. trust_remote_code=True,
  344. )
  345. sampler = None
  346. dataloader = DataLoader(
  347. dataset,
  348. batch_size=args.batch_size,
  349. sampler=sampler,
  350. shuffle=False,
  351. num_workers=args.num_workers,
  352. prefetch_factor=args.prefetch,
  353. collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
  354. )
  355. for _ in range(3):
  356. print(f"Running {_} times")
  357. total_llm_time = 0
  358. total_token2wav_time = 0
  359. total_data_load_time = 0
  360. total_llm_post_processing_time = 0
  361. total_audio_save_time = 0
  362. total_audio_processing_time_in_collator = 0
  363. total_speech_tokenization_time_in_collator = 0
  364. total_text_tokenization_time_in_collator = 0
  365. total_audio_samples = 0
  366. start_time = time.time()
  367. total_steps = len(dataset)
  368. if rank == 0:
  369. progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
  370. last_batch_end_time = time.time()
  371. for batch in dataloader:
  372. data_loaded_time = time.time()
  373. total_data_load_time += data_loaded_time - last_batch_end_time
  374. total_audio_processing_time_in_collator += batch["audio_processing_time"]
  375. total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
  376. total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
  377. with torch.no_grad():
  378. llm_start_time = time.time()
  379. if args.backend == "hf":
  380. input_ids_list = batch["input_ids"]
  381. if len(input_ids_list) == 1:
  382. input_ids = input_ids_list[0].unsqueeze(0)
  383. attention_mask = torch.ones_like(input_ids)
  384. else:
  385. max_len = max([len(input_ids) for input_ids in input_ids_list])
  386. input_ids_list_new = [
  387. torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
  388. for input_ids in input_ids_list
  389. ]
  390. input_ids = torch.stack(input_ids_list_new)
  391. attention_mask = torch.zeros_like(input_ids)
  392. for i in range(len(input_ids_list)):
  393. attention_mask[i, :len(input_ids_list[i])] = 1
  394. input_ids = input_ids.to(device)
  395. outputs = model.generate(
  396. input_ids=input_ids.to(device),
  397. attention_mask=attention_mask.to(device),
  398. max_new_tokens=2048,
  399. do_sample=True,
  400. top_p=args.top_p,
  401. temperature=args.temperature,
  402. repetition_penalty=1.1,
  403. top_k=args.top_k,
  404. )
  405. torch.cuda.synchronize()
  406. elif args.backend == "trtllm":
  407. batch_input_ids = list(batch["input_ids"])
  408. input_lengths = [x.size(0) for x in batch_input_ids]
  409. end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
  410. print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
  411. outputs = runner.generate(
  412. batch_input_ids=batch_input_ids,
  413. max_new_tokens=2048,
  414. end_id=end_id,
  415. pad_id=end_id,
  416. temperature=args.temperature,
  417. top_k=args.top_k,
  418. top_p=args.top_p,
  419. repetition_penalty=1.1,
  420. num_return_sequences=1,
  421. streaming=False,
  422. output_sequence_lengths=True,
  423. output_generation_logits=False,
  424. return_dict=True,
  425. return_all_generated_tokens=False
  426. )
  427. torch.cuda.synchronize()
  428. output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
  429. num_output_sents, num_beams, _ = output_ids.size()
  430. assert num_beams == 1
  431. beam = 0
  432. batch_size = len(batch["input_ids"])
  433. num_return_sequences = num_output_sents // batch_size
  434. assert num_return_sequences == 1
  435. outputs = []
  436. for i in range(batch_size * num_return_sequences):
  437. batch_idx = i // num_return_sequences
  438. seq_idx = i % num_return_sequences
  439. output_begin = input_lengths[batch_idx]
  440. output_end = sequence_lengths[i][beam]
  441. outputs_i = output_ids[i][beam][:output_end].tolist()
  442. outputs.append(outputs_i)
  443. elif args.backend == "vllm":
  444. input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
  445. sampling_params = SamplingParams(
  446. temperature=args.temperature,
  447. top_p=args.top_p,
  448. top_k=args.top_k,
  449. repetition_penalty=1.1,
  450. max_tokens=2048,
  451. )
  452. outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
  453. print(outputs)
  454. for j, output in enumerate(outputs):
  455. outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
  456. elif args.backend == "trtllm-serve":
  457. if args.batch_size > 1:
  458. outputs = asyncio.run(send_batch_requests_async(
  459. args.openai_api_base,
  460. args.openai_model_name,
  461. batch["chat_list"],
  462. args.temperature,
  463. args.top_p,
  464. args.top_k,
  465. ))
  466. else:
  467. outputs = []
  468. for i, chat in enumerate(batch["chat_list"]):
  469. payload = {
  470. "model": args.openai_model_name,
  471. "messages": chat,
  472. "max_tokens": 2048,
  473. "temperature": args.temperature,
  474. "top_p": args.top_p,
  475. "top_k": args.top_k,
  476. "repetition_penalty": 1.1,
  477. "stop": ["<|eos1|>", "<|eos|>"],
  478. "stream": False,
  479. }
  480. response = requests.post(args.openai_api_base, json=payload)
  481. response.raise_for_status()
  482. response_json = response.json()
  483. generated_content = response_json['choices'][0]['message']['content']
  484. outputs.append(generated_content)
  485. llm_end_time = time.time()
  486. total_llm_time += (llm_end_time - llm_start_time)
  487. items_for_token_2wav = []
  488. for i in range(len(batch["ids"])):
  489. llm_post_processing_start_time = time.time()
  490. if args.backend == "trtllm-serve":
  491. speech_tokens_str = outputs[i].strip().split('><')
  492. if len(speech_tokens_str) > 1:
  493. speech_tokens_str = [
  494. t if t.startswith('<') else '<' + t for t in speech_tokens_str
  495. ]
  496. speech_tokens_str = [
  497. t if t.endswith('>') else t + '>' for t in speech_tokens_str
  498. ]
  499. speech_ids = extract_speech_ids(speech_tokens_str)
  500. else:
  501. input_length = len(batch["input_ids"][i])
  502. generated_ids = outputs[i][input_length:]
  503. speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
  504. speech_ids = extract_speech_ids(speech_tokens_str)
  505. print(i, speech_ids)
  506. if len(speech_ids) == 0:
  507. print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
  508. continue
  509. if args.prompt_text is not None:
  510. current_prompt_text = args.prompt_text
  511. current_prompt_audio = prompt_speech_16k
  512. else:
  513. current_prompt_text = batch["prompt_text"][i]
  514. current_prompt_audio = batch["prompt_audio_list"][i]
  515. llm_post_processing_end_time = time.time()
  516. total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
  517. if current_prompt_audio is not None:
  518. items_for_token_2wav.append({
  519. "speech_ids": speech_ids,
  520. "prompt_audio": current_prompt_audio.squeeze(0),
  521. "id": batch["ids"][i]
  522. })
  523. else:
  524. print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
  525. for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
  526. t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
  527. if not t2w_batch:
  528. continue
  529. t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
  530. t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
  531. t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
  532. t2w_ids = [item["id"] for item in t2w_batch]
  533. token2wav_start_time = time.time()
  534. generated_wavs = token2wav_model(
  535. t2w_generated_speech_tokens_list,
  536. t2w_prompt_audios_list,
  537. t2w_prompt_audios_sample_rate,
  538. )
  539. torch.cuda.synchronize()
  540. token2wav_end_time = time.time()
  541. total_token2wav_time += (token2wav_end_time - token2wav_start_time)
  542. audio_save_start_time = time.time()
  543. for j, audio_hat in enumerate(generated_wavs):
  544. generated_wave = audio_hat.squeeze().cpu().numpy()
  545. total_audio_samples += len(generated_wave)
  546. target_sample_rate = 24000
  547. utt = t2w_ids[j]
  548. sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
  549. print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
  550. audio_save_end_time = time.time()
  551. total_audio_save_time += audio_save_end_time - audio_save_start_time
  552. if rank == 0:
  553. progress_bar.update(world_size * len(batch["ids"]))
  554. last_batch_end_time = time.time()
  555. if rank == 0:
  556. progress_bar.close()
  557. end_time = time.time()
  558. target_sample_rate = 24000
  559. total_audio_duration_seconds = total_audio_samples / target_sample_rate
  560. log_file_path = os.path.join(args.output_dir, "log.txt")
  561. with open(log_file_path, 'w') as f:
  562. args_dict = vars(args)
  563. log_data = {
  564. "args": args_dict,
  565. "data_load_time_seconds": total_data_load_time,
  566. "audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
  567. "speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
  568. "text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
  569. "llm_time_seconds": total_llm_time,
  570. "llm_post_processing_time_seconds": total_llm_post_processing_time,
  571. "token2wav_time_seconds": total_token2wav_time,
  572. "audio_save_time_seconds": total_audio_save_time,
  573. "total_audio_duration_seconds": total_audio_duration_seconds,
  574. "pipeline_time_seconds": end_time - start_time,
  575. }
  576. print(log_data)
  577. f.write(json.dumps(log_data, indent=4))
  578. print(f"Metrics logged to {log_file_path}")
  579. if __name__ == "__main__":
  580. args = get_args()
  581. if args.backend == "vllm":
  582. from vllm import LLM, SamplingParams
  583. elif args.backend == "trtllm":
  584. import tensorrt_llm
  585. from tensorrt_llm.runtime import ModelRunnerCpp
  586. elif args.backend == "hf":
  587. from transformers import AutoModelForCausalLM
  588. elif args.backend == "trtllm-serve":
  589. pass
  590. else:
  591. raise ValueError(f"Unsupported backend: {args.backend}")
  592. main(args)