# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. """ Example Usage
  16. dataset=zero_shot_zh
  17. output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
  18. token2wav_path=/workspace/CosyVoice2-0.5B
  19. CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
  20. torchrun --nproc_per_node=8 \
  21. infer_dataset.py \
  22. --output-dir $output_dir \
  23. --llm-model-name-or-path $llm_path/merged_hf_model \
  24. --token2wav-path $token2wav_path \
  25. --split-name ${dataset} || exit 1
  26. """
import argparse
import json
import os
import sys
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
import soundfile as sf
import s3tokenizer
from functools import partial

sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")

try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"
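# With this template, a chat such as
#   [{"role": "user", "content": "<target text>"},
#    {"role": "assistant", "content": "<|s_1|><|s_2|>"}]
# renders (speech token values here are illustrative) to:
#   <|im_start|>user\nConvert the text to speech: <target text><|im_end|>\n<|im_start|>assistant\n<|SPEECH_GENERATION_START|><|s_1|><|s_2|>
# so, with continue_final_message=True, generation continues the assistant turn
# after the prompt's speech tokens.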
def audio_decode_cosyvoice2(
    audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
    """
    Decode CosyVoice2 speech tokens to a waveform, conditioned on the prompt
    text and 16 kHz prompt speech: the flow model produces a mel spectrogram
    and the HiFT vocoder converts it to audio.
    """
    model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
        "empty", prompt_text, prompt_speech_16k, 24000
    )
    tts_mel, _ = codec_decoder.model.flow.inference(
        token=audio_tokens.to(codec_decoder.model.device),
        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
            codec_decoder.model.device
        ),
        prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
            codec_decoder.model.device
        ),
        prompt_token_len=torch.tensor(
            [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
        ).to(codec_decoder.model.device),
        prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
            codec_decoder.model.device
        ),
        prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
            codec_decoder.model.device
        ),
        embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
        finalize=True,
    )
    audio_hat, _ = codec_decoder.model.hift.inference(
        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
    )
    return audio_hat
def extract_speech_ids(speech_tokens_str):
    """Extract speech IDs from token strings like <|s_23456|>"""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
    """Convert CosyVoice2 tokens to a speech ID string like <|s_23456|>"""
    speech_id_str = ""
    for token in cosy2_tokens:
        speech_id_str += f"<|s_{token}|>"
    return speech_id_str
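# For example (illustrative token values):
#   convert_cosy2_tokens_to_speech_id_str([12, 345]) -> "<|s_12|><|s_345|>"
#   extract_speech_ids(["<|s_12|>", "<|s_345|>"])    -> [12, 345]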
def get_args():
    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="dir to save result"
    )
    parser.add_argument(
        "--batch-size",
        default=1,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=1, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=5, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--llm-model-name-or-path",
        required=True,
        type=str,
        help="LLM model path (includes both model and tokenizer)",
    )
    parser.add_argument(
        "--token2wav-path",
        required=True,
        type=str,
        help="CosyVoice2 token2wav model path",
    )
    parser.add_argument(
        "--prompt-text",
        type=str,
        default=None,
        help="The prompt text for CosyVoice2",
    )
    parser.add_argument(
        "--prompt-speech-path",
        type=str,
        default=None,
        help="The path to the prompt speech for CosyVoice2",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="top p for sampling",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="temperature for sampling",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=50,
        help="top k for sampling",
    )
    args = parser.parse_args()
    return args
def data_collator(batch, tokenizer, s3_tokenizer):
    """Simplified data collator for batch_size=1 processing"""
    target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
    mels, prompt_audio_cosy2tokens_list = [], []
    for item in batch:
        prompt_text, target_text = (
            item["prompt_text"],
            item["target_text"],
        )
        prompt_text_list.append(prompt_text)
        # Combine prompt and target text
        full_text = prompt_text + target_text

        # get prompt audio for CosyVoice2 (convert to 16kHz)
        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
        # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
        print(ref_audio_org.shape)

        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org

        prompt_audio_list.append(ref_audio)

        if "prompt_audio_cosy2_tokens" in item:
            prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
            prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
        else:
            # convert to float first
            mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))

    if len(mels) > 0:
        mels, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
        for i in range(len(codes)):
            prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])

    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
        prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
        # Create chat template for LLM generation
        chat = [
            {"role": "user", "content": full_text},
            {"role": "assistant", "content": prompt_audio_cosy2_id_str}
        ]
        if 'system' in tokenizer.chat_template:
            tokenizer.chat_template = TEMPLATE
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True
        )
        input_ids_list.append(input_ids.squeeze(0))

    # For batch_size=1, no need to pad
    if len(input_ids_list) == 1:
        input_ids = input_ids_list[0].unsqueeze(0)
    else:
        # Handle batch > 1 if needed
        max_len = max([len(input_ids) for input_ids in input_ids_list])
        input_ids_list = [
            torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
            for input_ids in input_ids_list
        ]
        input_ids = torch.stack(input_ids_list)

    ids = [item["id"] for item in batch]

    return {
        "input_ids": input_ids,
        "ids": ids,
        "prompt_text": prompt_text_list,
        "prompt_audio_list": prompt_audio_list,
    }
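# Based on the fields accessed above, each dataset item is expected to provide
# "id", "prompt_text", "target_text" and "prompt_audio" ({"array", "sampling_rate"});
# "prompt_audio_cosy2_tokens" is optional and, when absent, prompt speech tokens are
# extracted on the fly with s3tokenizer.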
def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(
        "Inference on multiple gpus, this gpu {}".format(local_rank)
        + ", rank {}, world_size {}".format(rank, world_size)
    )
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    # Load LLM model and tokenizer directly
    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
    model.eval()
    model.to(device)

    cosyvoice_codec = CosyVoice2(
        args.token2wav_path, load_jit=True, load_trt=True, fp16=True
    )
    if args.prompt_speech_path:
        prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
    else:
        prompt_speech_16k = None

    s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
    dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
    dataset = load_dataset(
        dataset_name,
        split=args.split_name,
        trust_remote_code=True,
    )

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
    )

    total_steps = len(dataset)
    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    for batch in dataloader:
        with torch.no_grad():
            input_ids = batch["input_ids"].to(device)

            # Generate speech tokens using LLM
            outputs = model.generate(
                input_ids,
                max_new_tokens=2048,  # Max length for generation
                do_sample=True,
                top_p=args.top_p,
                temperature=args.temperature,
                top_k=args.top_k,
            )

            # Process each sample in the batch
            for i in range(len(batch["ids"])):
                # Extract generated tokens (excluding input)
                input_length = input_ids[i].shape[0]
                generated_ids = outputs[i][input_length:-1]  # Remove last token if needed
                speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

                # Extract speech IDs from token strings like <|s_23456|>
                speech_ids = extract_speech_ids(speech_tokens_str)

                if len(speech_ids) == 0:
                    print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                    continue

                # Convert to tensor for CosyVoice2
                audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)

                if args.prompt_text is not None:
                    current_prompt_text = args.prompt_text
                    current_prompt_audio = prompt_speech_16k
                else:
                    current_prompt_text = batch["prompt_text"][i]
                    current_prompt_audio = batch["prompt_audio_list"][i]

                if current_prompt_audio is not None:
                    # Generate audio using CosyVoice2
                    audio_hat = audio_decode_cosyvoice2(
                        audio_tokens,
                        current_prompt_text,
                        current_prompt_audio,
                        cosyvoice_codec,
                    )
                    # Convert to numpy and save
                    generated_wave = audio_hat.squeeze(0).cpu().numpy()
                    target_sample_rate = 24000
                    utt = batch["ids"][i]
                    sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
                    print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
                else:
                    print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")

        if rank == 0:
            progress_bar.update(world_size * len(batch["ids"]))

    if rank == 0:
        progress_bar.close()

    dist.barrier()
    dist.destroy_process_group()
if __name__ == "__main__":
    main()