offline_inference.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
  2. # SPDX-License-Identifier: Apache-2.0
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """ Example Usage
  16. CUDA_VISIBLE_DEVICES=0 \
  17. python3 offline_inference.py \
  18. --output-dir $output_dir \
  19. --llm-model-name-or-path $huggingface_model_local_dir \
  20. --token2wav-path $model_scope_model_local_dir \
  21. --backend $backend \
  22. --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
  23. --engine-dir $trt_engines_dir \
  24. --split-name ${dataset} || exit 1
  25. """
  26. import argparse
  27. import json
  28. import os
  29. import sys
  30. from pathlib import Path
  31. import torch
  32. import torch.distributed as dist
  33. import torch.nn.functional as F
  34. import torchaudio
  35. from cosyvoice.utils.file_utils import load_wav
  36. from datasets import load_dataset
  37. from transformers import AutoTokenizer
  38. from torch.utils.data import DataLoader, Dataset
  39. from tqdm import tqdm
  40. import soundfile as sf
  41. import s3tokenizer
  42. from functools import partial
  43. import time
  44. from token2wav import CosyVoice2_Token2Wav
  45. sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
  46. try:
  47. torch.multiprocessing.set_start_method("spawn")
  48. except RuntimeError:
  49. pass
  50. def extract_speech_ids(speech_tokens_str):
  51. """Extract speech IDs from token strings like <|s_23456|>"""
  52. speech_ids = []
  53. for token_str in speech_tokens_str:
  54. if token_str.startswith('<|s_') and token_str.endswith('|>'):
  55. num_str = token_str[4:-2]
  56. num = int(num_str)
  57. speech_ids.append(num)
  58. else:
  59. print(f"Unexpected token: {token_str}")
  60. return speech_ids
  61. def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
  62. """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
  63. speech_id_str = ""
  64. for token in cosy2_tokens:
  65. speech_id_str += f"<|s_{token}|>"
  66. return speech_id_str
  67. def get_args():
  68. parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
  69. parser.add_argument(
  70. "--split-name",
  71. type=str,
  72. default="wenetspeech4tts",
  73. help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
  74. )
  75. parser.add_argument(
  76. "--output-dir", required=True, type=str, help="dir to save result"
  77. )
  78. parser.add_argument(
  79. "--batch-size",
  80. default=1,
  81. type=int,
  82. help="batch size (per-device) for inference",
  83. )
  84. parser.add_argument(
  85. "--token2wav-batch-size",
  86. default=1,
  87. type=int,
  88. help="batch size (per-device) for inference",
  89. )
  90. parser.add_argument(
  91. "--num-workers", type=int, default=0, help="workers for dataloader"
  92. )
  93. parser.add_argument(
  94. "--prefetch", type=int, default=None, help="prefetch for dataloader"
  95. )
  96. parser.add_argument(
  97. "--llm-model-name-or-path",
  98. required=True,
  99. type=str,
  100. help="LLM model path (includes both model and tokenizer)",
  101. )
  102. parser.add_argument(
  103. "--token2wav-path",
  104. required=True,
  105. type=str,
  106. help="CosyVoice2 token2wav model path",
  107. )
  108. parser.add_argument(
  109. "--prompt-text",
  110. type=str,
  111. default=None,
  112. help="The prompt text for CosyVoice2",
  113. )
  114. parser.add_argument(
  115. "--prompt-speech-path",
  116. type=str,
  117. default=None,
  118. help="The path to the prompt speech for CosyVoice2",
  119. )
  120. parser.add_argument(
  121. "--top-p",
  122. type=float,
  123. default=0.95,
  124. help="top p for sampling",
  125. )
  126. parser.add_argument(
  127. "--temperature",
  128. type=float,
  129. default=0.8,
  130. help="temperature for sampling",
  131. )
  132. parser.add_argument(
  133. "--top-k",
  134. type=int,
  135. default=50,
  136. help="top k for sampling",
  137. )
  138. parser.add_argument(
  139. "--backend",
  140. type=str,
  141. default="hf",
  142. choices=["hf", "trtllm", "vllm"],
  143. help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
  144. )
  145. parser.add_argument(
  146. "--engine-dir",
  147. type=str,
  148. default=None,
  149. help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
  150. )
  151. parser.add_argument(
  152. "--kv-cache-free-gpu-memory-fraction",
  153. type=float,
  154. default=0.6,
  155. help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
  156. )
  157. args = parser.parse_args()
  158. return args
  159. def data_collator(batch, tokenizer, s3_tokenizer):
  160. """Simplified data collator for batch_size=1 processing"""
  161. collator_start_time = time.time()
  162. total_audio_processing_time = 0
  163. total_speech_tokenization_time = 0
  164. total_text_tokenization_time = 0
  165. target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
  166. device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
  167. input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
  168. prompt_text_after_apply_template_list = []
  169. mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
  170. for _, item in enumerate(batch):
  171. audio_processing_start_time = time.time()
  172. prompt_text, target_text = (
  173. item["prompt_text"],
  174. item["target_text"],
  175. )
  176. prompt_text_list.append(prompt_text)
  177. full_text = prompt_text + target_text
  178. full_text_list.append(full_text)
  179. # remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
  180. puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\'']
  181. for p in puncts:
  182. if p in full_text:
  183. full_text = full_text.replace(p, '')
  184. print(f"removed {p} from {full_text}")
  185. # get prompt audio for CosyVoice2 (convert to 16kHz)
  186. ref_audio_org, ref_sr = (
  187. item["prompt_audio"]["array"],
  188. item["prompt_audio"]["sampling_rate"],
  189. )
  190. ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
  191. print(ref_audio_org.shape)
  192. if ref_sr != target_sample_rate:
  193. resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
  194. ref_audio = resampler(ref_audio_org)
  195. else:
  196. ref_audio = ref_audio_org
  197. prompt_audio_list.append(ref_audio)
  198. audio_processing_end_time = time.time()
  199. total_audio_processing_time += audio_processing_end_time - audio_processing_start_time
  200. speech_tokenization_start_time = time.time()
  201. if "prompt_audio_cosy2_tokens" in item:
  202. prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
  203. prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
  204. else:
  205. mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
  206. if len(mels) > 0:
  207. mels, mels_lens = s3tokenizer.padding(mels)
  208. codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
  209. for i in range(len(codes)):
  210. prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
  211. speech_tokenization_end_time = time.time()
  212. total_speech_tokenization_time += speech_tokenization_end_time - speech_tokenization_start_time
  213. for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
  214. text_tokenization_start_time = time.time()
  215. prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
  216. # Create chat template for LLM generation
  217. chat = [
  218. {"role": "user", "content": full_text_list[i]},
  219. {"role": "assistant", "content": prompt_audio_cosy2_id_str}
  220. ]
  221. assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"
  222. input_ids = tokenizer.apply_chat_template(
  223. chat,
  224. tokenize=True,
  225. return_tensors='pt',
  226. continue_final_message=True
  227. )
  228. input_ids_list.append(input_ids.squeeze(0))
  229. prompt_text_after_apply_template = f"<|sos|>{full_text_list[i]}<|task_id|>{prompt_audio_cosy2_id_str}"
  230. prompt_text_after_apply_template_list.append(prompt_text_after_apply_template)
  231. text_tokenization_end_time = time.time()
  232. total_text_tokenization_time += text_tokenization_end_time - text_tokenization_start_time
  233. ids = [item["id"] for item in batch]
  234. return {
  235. "input_ids": input_ids_list,
  236. "ids": ids,
  237. "prompt_text": prompt_text_list,
  238. "prompt_audio_list": prompt_audio_list,
  239. "prompt_text_after_apply_template": prompt_text_after_apply_template_list,
  240. "audio_processing_time": total_audio_processing_time,
  241. "speech_tokenization_time": total_speech_tokenization_time,
  242. "text_tokenization_time": total_text_tokenization_time,
  243. }
  244. def init_distributed():
  245. world_size = int(os.environ.get("WORLD_SIZE", 1))
  246. local_rank = int(os.environ.get("LOCAL_RANK", 0))
  247. rank = int(os.environ.get("RANK", 0))
  248. print(
  249. "Inference on multiple gpus, this gpu {}".format(local_rank)
  250. + ", rank {}, world_size {}".format(rank, world_size)
  251. )
  252. torch.cuda.set_device(local_rank)
  253. dist.init_process_group("nccl")
  254. return world_size, local_rank, rank
  255. def main(args):
  256. os.makedirs(args.output_dir, exist_ok=True)
  257. assert torch.cuda.is_available()
  258. local_rank, world_size, rank = 0, 1, 0
  259. device = torch.device(f"cuda:{local_rank}")
  260. tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
  261. if args.backend == "hf":
  262. model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
  263. model.eval()
  264. model.to(device)
  265. runner = None
  266. elif args.backend == "trtllm":
  267. if args.engine_dir is None:
  268. raise ValueError("--engine-dir is required when backend is 'trtllm'")
  269. runtime_rank = tensorrt_llm.mpi_rank()
  270. model = None
  271. runner_kwargs = dict(
  272. engine_dir=args.engine_dir,
  273. rank=runtime_rank,
  274. max_output_len=2048,
  275. enable_context_fmha_fp32_acc=False,
  276. max_batch_size=args.batch_size,
  277. max_input_len=512,
  278. kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
  279. cuda_graph_mode=False,
  280. gather_generation_logits=False,
  281. )
  282. runner = ModelRunnerCpp.from_dir(**runner_kwargs)
  283. elif args.backend == "vllm":
  284. model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
  285. runner = None
  286. else:
  287. raise ValueError(f"Unsupported backend: {args.backend}")
  288. token2wav_model = CosyVoice2_Token2Wav(
  289. model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
  290. )
  291. if args.prompt_speech_path:
  292. prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
  293. else:
  294. prompt_speech_16k = None
  295. s3_tokenizer = s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device) if 'zero' in args.split_name else None
  296. dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
  297. dataset = load_dataset(
  298. dataset_name,
  299. split=args.split_name,
  300. trust_remote_code=True,
  301. )
  302. sampler = None
  303. dataloader = DataLoader(
  304. dataset,
  305. batch_size=args.batch_size,
  306. sampler=sampler,
  307. shuffle=False,
  308. num_workers=args.num_workers,
  309. prefetch_factor=args.prefetch,
  310. collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
  311. )
  312. for _ in range(3):
  313. print(f"Running {_} times")
  314. total_llm_time = 0
  315. total_token2wav_time = 0
  316. total_data_load_time = 0
  317. total_llm_post_processing_time = 0
  318. total_audio_save_time = 0
  319. total_audio_processing_time_in_collator = 0
  320. total_speech_tokenization_time_in_collator = 0
  321. total_text_tokenization_time_in_collator = 0
  322. total_audio_samples = 0
  323. start_time = time.time()
  324. total_steps = len(dataset)
  325. if rank == 0:
  326. progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
  327. last_batch_end_time = time.time()
  328. for batch in dataloader:
  329. data_loaded_time = time.time()
  330. total_data_load_time += data_loaded_time - last_batch_end_time
  331. total_audio_processing_time_in_collator += batch["audio_processing_time"]
  332. total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
  333. total_text_tokenization_time_in_collator += batch["text_tokenization_time"]
  334. with torch.no_grad():
  335. llm_start_time = time.time()
  336. if args.backend == "hf":
  337. input_ids_list = batch["input_ids"]
  338. if len(input_ids_list) == 1:
  339. input_ids = input_ids_list[0].unsqueeze(0)
  340. attention_mask = torch.ones_like(input_ids)
  341. else:
  342. max_len = max([len(input_ids) for input_ids in input_ids_list])
  343. input_ids_list_new = [
  344. torch.cat([input_ids, torch.full((max_len - len(input_ids),), tokenizer.pad_token_id)])
  345. for input_ids in input_ids_list
  346. ]
  347. input_ids = torch.stack(input_ids_list_new)
  348. attention_mask = torch.zeros_like(input_ids)
  349. for i in range(len(input_ids_list)):
  350. attention_mask[i, :len(input_ids_list[i])] = 1
  351. input_ids = input_ids.to(device)
  352. outputs = model.generate(
  353. input_ids=input_ids.to(device),
  354. attention_mask=attention_mask.to(device),
  355. max_new_tokens=2048,
  356. do_sample=True,
  357. top_p=args.top_p,
  358. temperature=args.temperature,
  359. repetition_penalty=1.1,
  360. top_k=args.top_k,
  361. )
  362. torch.cuda.synchronize()
  363. elif args.backend == "trtllm":
  364. batch_input_ids = list(batch["input_ids"])
  365. input_lengths = [x.size(0) for x in batch_input_ids]
  366. end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
  367. print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")
  368. outputs = runner.generate(
  369. batch_input_ids=batch_input_ids,
  370. max_new_tokens=2048,
  371. end_id=end_id,
  372. pad_id=end_id,
  373. temperature=args.temperature,
  374. top_k=args.top_k,
  375. top_p=args.top_p,
  376. repetition_penalty=1.1,
  377. num_return_sequences=1,
  378. streaming=False,
  379. output_sequence_lengths=True,
  380. output_generation_logits=False,
  381. return_dict=True,
  382. return_all_generated_tokens=False
  383. )
  384. torch.cuda.synchronize()
  385. output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
  386. num_output_sents, num_beams, _ = output_ids.size()
  387. assert num_beams == 1
  388. beam = 0
  389. batch_size = len(batch["input_ids"])
  390. num_return_sequences = num_output_sents // batch_size
  391. assert num_return_sequences == 1
  392. outputs = []
  393. for i in range(batch_size * num_return_sequences):
  394. batch_idx = i // num_return_sequences
  395. seq_idx = i % num_return_sequences
  396. output_begin = input_lengths[batch_idx]
  397. output_end = sequence_lengths[i][beam]
  398. outputs_i = output_ids[i][beam][:output_end].tolist()
  399. outputs.append(outputs_i)
  400. elif args.backend == "vllm":
  401. input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
  402. sampling_params = SamplingParams(
  403. temperature=args.temperature,
  404. top_p=args.top_p,
  405. top_k=args.top_k,
  406. repetition_penalty=1.1,
  407. max_tokens=2048,
  408. )
  409. outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
  410. print(outputs)
  411. for j, output in enumerate(outputs):
  412. outputs[j] = input_ids_list[j] + output.outputs[0].token_ids
  413. llm_end_time = time.time()
  414. total_llm_time += (llm_end_time - llm_start_time)
  415. items_for_token_2wav = []
  416. for i in range(len(batch["ids"])):
  417. llm_post_processing_start_time = time.time()
  418. input_length = len(batch["input_ids"][i])
  419. generated_ids = outputs[i][input_length:]
  420. speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
  421. speech_ids = extract_speech_ids(speech_tokens_str)
  422. print(i, speech_ids)
  423. if len(speech_ids) == 0:
  424. print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
  425. continue
  426. if args.prompt_text is not None:
  427. current_prompt_text = args.prompt_text
  428. current_prompt_audio = prompt_speech_16k
  429. else:
  430. current_prompt_text = batch["prompt_text"][i]
  431. current_prompt_audio = batch["prompt_audio_list"][i]
  432. llm_post_processing_end_time = time.time()
  433. total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
  434. if current_prompt_audio is not None:
  435. items_for_token_2wav.append({
  436. "speech_ids": speech_ids,
  437. "prompt_audio": current_prompt_audio.squeeze(0),
  438. "id": batch["ids"][i]
  439. })
  440. else:
  441. print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
  442. for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
  443. t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
  444. if not t2w_batch:
  445. continue
  446. t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
  447. t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
  448. t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
  449. t2w_ids = [item["id"] for item in t2w_batch]
  450. token2wav_start_time = time.time()
  451. generated_wavs = token2wav_model(
  452. t2w_generated_speech_tokens_list,
  453. t2w_prompt_audios_list,
  454. t2w_prompt_audios_sample_rate,
  455. )
  456. torch.cuda.synchronize()
  457. token2wav_end_time = time.time()
  458. total_token2wav_time += (token2wav_end_time - token2wav_start_time)
  459. audio_save_start_time = time.time()
  460. for j, audio_hat in enumerate(generated_wavs):
  461. generated_wave = audio_hat.squeeze().cpu().numpy()
  462. total_audio_samples += len(generated_wave)
  463. target_sample_rate = 24000
  464. utt = t2w_ids[j]
  465. sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
  466. print(f"Generated audio for sample {utt} with {len(t2w_generated_speech_tokens_list[j])} tokens")
  467. audio_save_end_time = time.time()
  468. total_audio_save_time += audio_save_end_time - audio_save_start_time
  469. if rank == 0:
  470. progress_bar.update(world_size * len(batch["ids"]))
  471. last_batch_end_time = time.time()
  472. if rank == 0:
  473. progress_bar.close()
  474. end_time = time.time()
  475. target_sample_rate = 24000
  476. total_audio_duration_seconds = total_audio_samples / target_sample_rate
  477. log_file_path = os.path.join(args.output_dir, "log.txt")
  478. with open(log_file_path, 'w') as f:
  479. args_dict = vars(args)
  480. log_data = {
  481. "args": args_dict,
  482. "data_load_time_seconds": total_data_load_time,
  483. "audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
  484. "speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
  485. "text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
  486. "llm_time_seconds": total_llm_time,
  487. "llm_post_processing_time_seconds": total_llm_post_processing_time,
  488. "token2wav_time_seconds": total_token2wav_time,
  489. "audio_save_time_seconds": total_audio_save_time,
  490. "total_audio_duration_seconds": total_audio_duration_seconds,
  491. "pipeline_time_seconds": end_time - start_time,
  492. }
  493. print(log_data)
  494. f.write(json.dumps(log_data, indent=4))
  495. print(f"Metrics logged to {log_file_path}")
  496. if __name__ == "__main__":
  497. args = get_args()
  498. if args.backend == "vllm":
  499. from vllm import LLM, SamplingParams
  500. elif args.backend == "trtllm":
  501. import tensorrt_llm
  502. from tensorrt_llm.runtime import ModelRunnerCpp
  503. elif args.backend == "hf":
  504. from transformers import AutoModelForCausalLM
  505. else:
  506. raise ValueError(f"Unsupported backend: {args.backend}")
  507. main(args)