  1. """ Example Usage
  2. CUDA_VISIBLE_DEVICES=0 \
  3. python3 infer_cosyvoice3_token2wav.py \
  4. --output-dir $output_dir \
  5. --llm-model-name-or-path $huggingface_model_local_dir \
  6. --token2wav-path $token2wav_model_dir \
  7. --backend $backend \
  8. --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
  9. --engine-dir $trt_engines_dir \
  10. --split-name ${dataset} || exit 1
  11. """
import argparse
import asyncio
import json
import os
import time
from functools import partial

import httpx
import requests
import s3tokenizer
import soundfile as sf
import torch
import torchaudio
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from token2wav_cosyvoice3 import CosyVoice3_Token2Wav

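# Use the "spawn" start method so CUDA can be used safely in DataLoader worker
# processes; a RuntimeError here just means the start method was already set.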
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass


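# Async helpers for the "trtllm-serve" backend: requests go concurrently to an
# OpenAI-compatible /v1/chat/completions endpoint, which returns the generated
# assistant message (a string of <|s_N|> speech tokens).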
async def send_request_async(client, url, payload):
    response = await client.post(url, json=payload, timeout=None)
    response.raise_for_status()
    response_json = response.json()
    return response_json['choices'][0]['message']['content']


async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
    async with httpx.AsyncClient() as client:
        tasks = []
        for chat in chats:
            payload = {
                "model": model_name,
                "messages": chat,
                "max_tokens": 2048,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": 1.1,
                "stop": ["<|eos1|>", "<|eos|>"],
                "stream": False,
            }
            tasks.append(send_request_async(client, api_base, payload))
        return await asyncio.gather(*tasks)


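# Speech tokens travel through the LLM as text tokens of the form <|s_N|>; the
# two helpers below convert between that string form and integer token IDs.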
def extract_speech_ids(speech_tokens_str):
    """Extract speech IDs from token strings like <|s_23456|>."""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


def convert_cosy3_tokens_to_speech_id_str(cosy3_tokens):
    """Convert CosyVoice3 tokens to a speech-ID string like <|s_23456|><|s_23457|>..."""
    if hasattr(cosy3_tokens, 'cpu'):
        cosy3_tokens = cosy3_tokens.cpu().numpy().tolist()
    speech_id_str = ""
    for token in cosy3_tokens:
        speech_id_str += f"<|s_{token}|>"
    return speech_id_str


def get_args():
    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice3")
    parser.add_argument(
        "--split-name", type=str, default="wenetspeech4tts",
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="directory to save results",
    )
    parser.add_argument(
        "--batch-size", default=1, type=int,
        help="batch size (per-device) for LLM inference",
    )
    parser.add_argument(
        "--token2wav-batch-size", default=1, type=int,
        help="batch size (per-device) for token2wav inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=0, help="workers for dataloader",
    )
    parser.add_argument(
        "--prefetch", type=int, default=None, help="prefetch factor for dataloader",
    )
    parser.add_argument(
        "--llm-model-name-or-path", required=True, type=str,
        help="CosyVoice3 HF LLM path (e.g. ./hf_cosyvoice3_llm)",
    )
    parser.add_argument(
        "--token2wav-path", required=True, type=str,
        help="CosyVoice3 model path (e.g. /workspace_yuekai/HF/Fun-CosyVoice3-0.5B-2512)",
    )
    parser.add_argument(
        "--enable-trt", action="store_true",
        help="Enable TensorRT for the flow decoder estimator",
    )
    parser.add_argument(
        "--streaming", action="store_true",
        help="Enable streaming token2wav inference",
    )
    parser.add_argument(
        "--top-p", type=float, default=0.95, help="top-p for sampling",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.8, help="temperature for sampling",
    )
    parser.add_argument(
        "--top-k", type=int, default=15, help="top-k for sampling",
    )
    parser.add_argument(
        "--backend", type=str, default="hf",
        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
        help="Backend to use for LLM inference",
    )
    parser.add_argument(
        "--engine-dir", type=str, default=None,
        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
    )
    parser.add_argument(
        "--kv-cache-free-gpu-memory-fraction", type=float, default=0.6,
        help="fraction of free GPU memory to allocate for the KV cache (TensorRT-LLM only)",
    )
    parser.add_argument(
        "--openai-api-base", type=str,
        default="http://localhost:8000/v1/chat/completions",
        help="OpenAI API base URL (for trtllm-serve backend)",
    )
    parser.add_argument(
        "--openai-model-name", type=str, default="trt_engines_bfloat16",
        help="model name to use with the OpenAI API (for trtllm-serve backend)",
    )
    parser.add_argument(
        "--epoch", type=int, default=1, help="number of epochs (passes over the dataset) to run",
    )
    return parser.parse_args()


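# Each batch item contributes: (a) the user turn "You are a helpful
# assistant.<|endofprompt|>" + prompt_text + target_text, and (b) the prompt
# audio's CosyVoice3 tokens as the opening of the assistant turn.
# apply_chat_template(..., continue_final_message=True) leaves that turn open
# so the LLM continues it with speech tokens for the target text.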
def data_collator(batch, tokenizer, s3_tokenizer):
    """Data collator: extracts cosy3 tokens from prompt_audio using the v3 s3 tokenizer."""
    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
    target_sample_rate = 16000
    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
    mels, prompt_audio_cosy3tokens_list, full_text_list = [], [], []
    chat_list = []
    for item in batch:
        prompt_text, target_text = item["prompt_text"], item["target_text"]
        prompt_text_list.append(prompt_text)
        full_text = 'You are a helpful assistant.<|endofprompt|>' + prompt_text + target_text
        full_text_list.append(full_text)
        # Get prompt audio (convert to 16 kHz for the s3 tokenizer)
        ref_audio = torch.from_numpy(item["prompt_audio"]["array"]).float().unsqueeze(0)
        ref_sr = item["prompt_audio"]["sampling_rate"]
        if ref_sr != target_sample_rate:
            ref_audio = torchaudio.transforms.Resample(ref_sr, target_sample_rate)(ref_audio)
        prompt_audio_list.append(ref_audio)
        # Extract cosy3 tokens from prompt_audio using the v3 s3 tokenizer
        mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
    # Batch tokenization with the v3 tokenizer
    if len(mels) > 0:
        mels_padded, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = s3_tokenizer.quantize(mels_padded.to(device), mels_lens.to(device))
        for i in range(len(codes)):
            prompt_audio_cosy3tokens_list.append(codes[i, :codes_lens[i].item()])
    # Build LLM inputs
    for i, prompt_audio_cosy3tokens in enumerate(prompt_audio_cosy3tokens_list):
        prompt_audio_cosy3_id_str = convert_cosy3_tokens_to_speech_id_str(prompt_audio_cosy3tokens)
        chat = [
            {"role": "user", "content": full_text_list[i]},
            {"role": "assistant", "content": prompt_audio_cosy3_id_str}
        ]
        chat_list.append(chat)
        input_ids = tokenizer.apply_chat_template(
            chat, tokenize=True, return_tensors='pt', continue_final_message=True)
        input_ids_list.append(input_ids.squeeze(0))
    ids = [item["id"] for item in batch]
    return {
        "input_ids": input_ids_list,
        "ids": ids,
        "prompt_text": prompt_text_list,
        "prompt_audio_list": prompt_audio_list,
        "chat_list": chat_list,
    }


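# Pipeline: (prompt text + prompt audio, target text) -> LLM emits speech
# tokens -> CosyVoice3 token2wav renders 24 kHz audio, saved as <id>.wav.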
def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    assert torch.cuda.is_available()
    local_rank = 0
    device = torch.device(f"cuda:{local_rank}")

    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)

    if args.backend == "hf":
        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
        model.eval()
        model.to(device)
        runner = None
    elif args.backend == "trtllm":
        if args.engine_dir is None:
            raise ValueError("--engine-dir is required when backend is 'trtllm'")
        runtime_rank = tensorrt_llm.mpi_rank()
        model = None
        runner_kwargs = dict(
            engine_dir=args.engine_dir,
            rank=runtime_rank,
            max_output_len=2048,
            enable_context_fmha_fp32_acc=False,
            max_batch_size=args.batch_size,
            max_input_len=512,
            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
            cuda_graph_mode=False,
            gather_generation_logits=False,
        )
        runner = ModelRunnerCpp.from_dir(**runner_kwargs)
    elif args.backend == "vllm":
        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
        runner = None
    elif args.backend == "trtllm-serve":
        model = None
        runner = None
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")

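    # token2wav converts the generated speech tokens back into a 24 kHz
    # waveform, optionally with a TensorRT engine for the flow decoder
    # estimator (--enable-trt).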
    token2wav_model = CosyVoice3_Token2Wav(
        model_dir=args.token2wav_path, enable_trt=args.enable_trt,
        device_id=local_rank, streaming=args.streaming,
    )
    # Load the v3 s3 tokenizer for prompt-audio tokenization in data_collator
    s3_tokenizer = s3tokenizer.load_model(
        f"{args.token2wav_path}/speech_tokenizer_v3.onnx"
    ).to(device).eval()

    dataset = load_dataset(
        "yuekai/seed_tts_cosy2",
        split=args.split_name,
        trust_remote_code=True,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
    )

    for epoch in range(args.epoch):
        print(f"Running epoch {epoch}")
        total_llm_time = 0
        total_token2wav_time = 0
        total_data_load_time = 0
        total_llm_post_processing_time = 0
        total_audio_save_time = 0
        total_audio_samples = 0
        start_time = time.time()
        progress_bar = tqdm(total=len(dataset), desc="Processing", unit="wavs")
        last_batch_end_time = time.time()

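        # Each stage is timed separately (data loading, LLM, post-processing,
        # token2wav, audio saving) and summarized in log.txt after the epoch.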
        for batch in dataloader:
            data_loaded_time = time.time()
            total_data_load_time += data_loaded_time - last_batch_end_time
            with torch.no_grad():
                llm_start_time = time.time()
                if args.backend == "hf":
                    input_ids_list = batch["input_ids"]
                    if len(input_ids_list) == 1:
                        input_ids = input_ids_list[0].unsqueeze(0)
                        attention_mask = torch.ones_like(input_ids)
                    else:
                        # Right-pad to the longest prompt; pad positions are
                        # masked out via the attention mask.
                        max_len = max(len(ids) for ids in input_ids_list)
                        input_ids_list_new = [
                            torch.cat([ids, torch.full((max_len - len(ids),), tokenizer.pad_token_id)])
                            for ids in input_ids_list
                        ]
                        input_ids = torch.stack(input_ids_list_new)
                        attention_mask = torch.zeros_like(input_ids)
                        for i in range(len(input_ids_list)):
                            attention_mask[i, :len(input_ids_list[i])] = 1
                    outputs = model.generate(
                        input_ids=input_ids.to(device),
                        attention_mask=attention_mask.to(device),
                        max_new_tokens=2048,
                        do_sample=True,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        repetition_penalty=1.1,
                        top_k=args.top_k,
                    )
                    torch.cuda.synchronize()
                elif args.backend == "trtllm":
                    batch_input_ids = list(batch["input_ids"])
                    # Stop on <|eos1|> if present in the vocab; it also serves
                    # as the pad id during generation.
                    end_id = (
                        tokenizer.convert_tokens_to_ids("<|eos1|>")
                        if "<|eos1|>" in tokenizer.get_vocab()
                        else tokenizer.eos_token_id
                    )
                    outputs = runner.generate(
                        batch_input_ids=batch_input_ids,
                        max_new_tokens=2048,
                        end_id=end_id,
                        pad_id=end_id,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        repetition_penalty=1.1,
                        num_return_sequences=1,
                        streaming=False,
                        output_sequence_lengths=True,
                        output_generation_logits=False,
                        return_dict=True,
                        return_all_generated_tokens=False,
                    )
                    torch.cuda.synchronize()
                    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
                    num_output_sents, num_beams, _ = output_ids.size()
                    assert num_beams == 1
                    batch_size = len(batch["input_ids"])
                    num_return_sequences = num_output_sents // batch_size
                    assert num_return_sequences == 1
                    # Keep full sequences (prompt + generation); the prompt is
                    # stripped later using the known input length.
                    outputs = []
                    for i in range(batch_size):
                        output_end = sequence_lengths[i][0]
                        outputs.append(output_ids[i][0][:output_end].tolist())
                elif args.backend == "vllm":
                    input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
                    sampling_params = SamplingParams(
                        temperature=args.temperature,
                        top_p=args.top_p,
                        top_k=args.top_k,
                        repetition_penalty=1.1,
                        max_tokens=2048,
                    )
                    outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
                    # Prepend the prompt so every backend yields full sequences;
                    # token_ids can be a tuple in recent vLLM, so coerce to list.
                    for j, output in enumerate(outputs):
                        outputs[j] = input_ids_list[j] + list(output.outputs[0].token_ids)
                elif args.backend == "trtllm-serve":
                    if args.batch_size > 1:
                        outputs = asyncio.run(send_batch_requests_async(
                            args.openai_api_base,
                            args.openai_model_name,
                            batch["chat_list"],
                            args.temperature,
                            args.top_p,
                            args.top_k,
                        ))
                    else:
                        outputs = []
                        for chat in batch["chat_list"]:
                            payload = {
                                "model": args.openai_model_name,
                                "messages": chat,
                                "max_tokens": 2048,
                                "temperature": args.temperature,
                                "top_p": args.top_p,
                                "top_k": args.top_k,
                                "repetition_penalty": 1.1,
                                "stop": ["<|eos1|>", "<|eos|>"],
                                "stream": False,
                            }
                            response = requests.post(args.openai_api_base, json=payload)
                            response.raise_for_status()
                            response_json = response.json()
                            outputs.append(response_json['choices'][0]['message']['content'])

                llm_end_time = time.time()
                total_llm_time += (llm_end_time - llm_start_time)

                items_for_token_2wav = []
                for i in range(len(batch["ids"])):
                    llm_post_processing_start_time = time.time()
                    if args.backend == "trtllm-serve":
                        # The server returns the speech tokens as one string;
                        # splitting on '><' drops the shared brackets, so
                        # restore them before parsing.
                        speech_tokens_str = outputs[i].strip().split('><')
                        if len(speech_tokens_str) > 1:
                            speech_tokens_str = [
                                t if t.startswith('<') else '<' + t for t in speech_tokens_str
                            ]
                            speech_tokens_str = [
                                t if t.endswith('>') else t + '>' for t in speech_tokens_str
                            ]
                        speech_ids = extract_speech_ids(speech_tokens_str)
                    else:
                        # Strip the prompt, then decode each generated token id
                        # individually into its <|s_N|> string form.
                        input_length = len(batch["input_ids"][i])
                        generated_ids = outputs[i][input_length:]
                        speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                        speech_ids = extract_speech_ids(speech_tokens_str)
                    print(i, speech_ids[:10], "...", f"total={len(speech_ids)}")
                    if len(speech_ids) == 0:
                        print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                        llm_post_processing_end_time = time.time()
                        total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
                        continue
                    current_prompt_audio = batch["prompt_audio_list"][i]
                    llm_post_processing_end_time = time.time()
                    total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
                    items_for_token_2wav.append({
                        "speech_ids": speech_ids,
                        "prompt_audio": current_prompt_audio.squeeze(0),
                        "id": batch["ids"][i],
                    })

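                # token2wav runs at its own batch size, decoupled from the LLM
                # batch size, so successful samples are re-chunked here.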
                for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
                    t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
                    if not t2w_batch:
                        continue
                    t2w_speech_tokens = [item["speech_ids"] for item in t2w_batch]
                    t2w_prompt_audios = [item["prompt_audio"] for item in t2w_batch]
                    t2w_sample_rates = [16000] * len(t2w_batch)

                    token2wav_start_time = time.time()
                    generated_wavs = token2wav_model(
                        t2w_speech_tokens, t2w_prompt_audios, t2w_sample_rates,
                        streaming=args.streaming,
                    )
                    token2wav_end_time = time.time()
                    total_token2wav_time += (token2wav_end_time - token2wav_start_time)

                    audio_save_start_time = time.time()
                    for j, audio_hat in enumerate(generated_wavs):
                        wav = audio_hat.squeeze().cpu().numpy()
                        total_audio_samples += len(wav)
                        sf.write(f"{args.output_dir}/{t2w_batch[j]['id']}.wav", wav, 24000)
                        print(f"Generated audio for sample {t2w_batch[j]['id']} with {len(t2w_speech_tokens[j])} tokens")
                    audio_save_end_time = time.time()
                    total_audio_save_time += audio_save_end_time - audio_save_start_time

            progress_bar.update(len(batch["ids"]))
            last_batch_end_time = time.time()

        progress_bar.close()
        end_time = time.time()
        total_audio_duration_seconds = total_audio_samples / 24000

        log_file_path = os.path.join(args.output_dir, "log.txt")
        with open(log_file_path, 'w') as f:
            log_data = {
                "args": vars(args),
                "data_load_time_seconds": total_data_load_time,
                "llm_time_seconds": total_llm_time,
                "llm_post_processing_time_seconds": total_llm_post_processing_time,
                "token2wav_time_seconds": total_token2wav_time,
                "audio_save_time_seconds": total_audio_save_time,
                "total_audio_duration_seconds": total_audio_duration_seconds,
                "pipeline_time_seconds": end_time - start_time,
            }
            print(log_data)
            f.write(json.dumps(log_data, indent=4))
        print(f"Metrics logged to {log_file_path}")


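# Backend-specific dependencies are imported lazily so only the selected
# backend's packages need to be installed; names imported here are module
# globals visible to main().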
if __name__ == "__main__":
    args = get_args()
    if args.backend == "vllm":
        from vllm import LLM, SamplingParams
    elif args.backend == "trtllm":
        import tensorrt_llm
        from tensorrt_llm.runtime import ModelRunnerCpp
    elif args.backend == "hf":
        from transformers import AutoModelForCausalLM
    elif args.backend == "trtllm-serve":
        pass
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
    main(args)