123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563 |
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ Example Usage
- CUDA_VISIBLE_DEVICES=0 \
- python3 offline_inference.py \
- --output-dir $output_dir \
- --llm-model-name-or-path $huggingface_model_local_dir \
- --token2wav-path $model_scope_model_local_dir \
- --backend $backend \
- --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
- --engine-dir $trt_engines_dir \
- --split-name ${dataset} || exit 1
- """
- import argparse
- import json
- import os
- import sys
- from pathlib import Path
- import torch
- import torch.distributed as dist
- import torch.nn.functional as F
- import torchaudio
- from cosyvoice.utils.file_utils import load_wav
- from datasets import load_dataset
- from transformers import AutoTokenizer
- from torch.utils.data import DataLoader, Dataset
- from tqdm import tqdm
- import soundfile as sf
- import s3tokenizer
- from functools import partial
- import time
- from token2wav import CosyVoice2_Token2Wav
# Extra import root appended to sys.path at import time (third-party
# Matcha-TTS checkout inside the CosyVoice workspace).
_MATCHA_TTS_DIR = "/workspace/CosyVoice/third_party/Matcha-TTS"
sys.path.append(_MATCHA_TTS_DIR)

# Request the "spawn" multiprocessing start method. set_start_method raises
# RuntimeError when a start method has already been fixed for this process;
# that case is deliberately ignored (best effort).
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass
def extract_speech_ids(speech_tokens_str):
    """Extract integer speech IDs from token strings like ``<|s_23456|>``.

    Args:
        speech_tokens_str: Iterable of decoded token strings.

    Returns:
        A list of the integer payloads of every well-formed ``<|s_N|>``
        token, in input order.  Any other string (including a token with the
        right delimiters but a non-integer payload) is reported on stdout as
        ``Unexpected token: ...`` and skipped.
    """
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            try:
                speech_ids.append(int(token_str[4:-2]))
                continue
            except ValueError:
                # Right delimiters, but the payload is not an integer: treat
                # it like any other unexpected token instead of aborting.
                pass
        print(f"Unexpected token: {token_str}")
    return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
    """Convert CosyVoice2 tokens to a speech-ID string like ``<|s_23456|>``.

    Args:
        cosy2_tokens: Iterable of token values; each one is formatted with
            an f-string, so anything with a sensible ``format()`` works.

    Returns:
        The concatenation of ``<|s_{token}|>`` for every token, in order
        (an empty string for an empty input).
    """
    # str.join builds the result in one pass instead of repeated `+=`.
    return "".join(f"<|s_{token}|>" for token in cosy2_tokens)
def get_args(argv=None):
    """Parse the command-line options of the offline inference script.

    Args:
        argv: Optional list of argument strings.  When ``None`` (the
            default) ``argparse`` reads ``sys.argv[1:]``, exactly as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="dir to save result"
    )
    parser.add_argument(
        "--batch-size",
        default=1,
        type=int,
        help="batch size (per-device) for inference",
    )
    parser.add_argument(
        "--token2wav-batch-size",
        default=1,
        type=int,
        # Previously a copy of the --batch-size help text.
        help="batch size (per-device) for token2wav (speech tokens to waveform) inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=0, help="workers for dataloader"
    )
    parser.add_argument(
        "--prefetch", type=int, default=None, help="prefetch for dataloader"
    )
    parser.add_argument(
        "--llm-model-name-or-path",
        required=True,
        type=str,
        help="LLM model path (includes both model and tokenizer)",
    )
    parser.add_argument(
        "--token2wav-path",
        required=True,
        type=str,
        help="CosyVoice2 token2wav model path",
    )
    parser.add_argument(
        "--prompt-text",
        type=str,
        default=None,
        help="The prompt text for CosyVoice2",
    )
    parser.add_argument(
        "--prompt-speech-path",
        type=str,
        default=None,
        help="The path to the prompt speech for CosyVoice2",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="top p for sampling",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="temperature for sampling",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=50,
        help="top k for sampling",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="hf",
        choices=["hf", "trtllm", "vllm"],
        help="Backend to use for LLM inference: 'hf' for HuggingFace, 'trtllm' for TensorRT-LLM, 'vllm' for VLLM",
    )
    parser.add_argument(
        "--engine-dir",
        type=str,
        default=None,
        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
    )
    parser.add_argument(
        "--kv-cache-free-gpu-memory-fraction",
        type=float,
        default=0.6,
        help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
    )
    return parser.parse_args(argv)
def data_collator(batch, tokenizer, s3_tokenizer):
    """Collate dataset items into LLM prompts and 16 kHz prompt audio.

    Each item is read through the keys ``"id"``, ``"prompt_text"``,
    ``"target_text"``, ``"prompt_audio"`` (with ``"array"`` and
    ``"sampling_rate"``) and, optionally, ``"prompt_audio_cosy2_tokens"``.
    Items without precomputed tokens are tokenized here from their prompt
    audio with ``s3_tokenizer``.

    Args:
        batch: List of dataset items as described above.
        tokenizer: Text tokenizer exposing ``chat_template`` and
            ``apply_chat_template``.
        s3_tokenizer: Speech tokenizer exposing ``device`` and ``quantize``,
            or ``None`` when every item carries precomputed tokens.

    Returns:
        A dict with per-item lists ``"input_ids"``, ``"ids"``,
        ``"prompt_text"``, ``"prompt_audio_list"`` and
        ``"prompt_text_after_apply_template"``, plus the accumulated
        ``"audio_processing_time"``, ``"speech_tokenization_time"`` and
        ``"text_tokenization_time"`` in seconds.

    Raises:
        ValueError: If an item needs speech tokenization but
            ``s3_tokenizer`` is ``None``.
    """
    total_audio_processing_time = 0
    total_speech_tokenization_time = 0
    total_text_tokenization_time = 0

    target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
    prompt_text_after_apply_template_list = []
    mels, full_text_list = [], []
    # One entry per item, aligned with full_text_list; entries whose tokens
    # still have to come from the quantizer hold None until it has run.
    prompt_audio_cosy2tokens_list = []
    pending_slots = []

    # remove the unnecessary punctuation for cosyvoice3 zero_shot_zh dataset
    # NOTE(review): '(' and ')' appear twice; the second pair was presumably
    # the full-width variants before an encoding round-trip - confirm.
    puncts = ['"', '(', ')', '“', '”', '‘', '(', ')', '\'']

    for item in batch:
        audio_processing_start_time = time.time()
        prompt_text, target_text = (
            item["prompt_text"],
            item["target_text"],
        )
        prompt_text_list.append(prompt_text)
        full_text = prompt_text + target_text
        for p in puncts:
            if p in full_text:
                full_text = full_text.replace(p, '')
                print(f"removed {p} from {full_text}")
        # Store the text only after cleaning; storing it before made the
        # punctuation removal above a no-op for the generated prompt.
        full_text_list.append(full_text)

        # get prompt audio for CosyVoice2 (convert to 16kHz)
        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
        print(ref_audio_org.shape)
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org
        prompt_audio_list.append(ref_audio)
        total_audio_processing_time += time.time() - audio_processing_start_time

        # Time every item's share of the speech tokenization, not only the
        # last one's.
        speech_tokenization_start_time = time.time()
        if "prompt_audio_cosy2_tokens" in item:
            prompt_audio_cosy2tokens_list.append(item["prompt_audio_cosy2_tokens"])
        else:
            pending_slots.append(len(prompt_audio_cosy2tokens_list))
            prompt_audio_cosy2tokens_list.append(None)
            mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
        total_speech_tokenization_time += time.time() - speech_tokenization_start_time

    if mels:
        if s3_tokenizer is None:
            raise ValueError(
                "s3_tokenizer is required for items without 'prompt_audio_cosy2_tokens'"
            )
        speech_tokenization_start_time = time.time()
        mels, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
        # Put each quantized sequence back into the slot of the item it
        # belongs to, trimmed to its own length.
        for row, slot in enumerate(pending_slots):
            prompt_audio_cosy2tokens_list[slot] = codes[row, :codes_lens[row].item()]
        total_speech_tokenization_time += time.time() - speech_tokenization_start_time

    if prompt_audio_cosy2tokens_list:
        # Loop-invariant check, done once per batch.
        assert 'system' not in tokenizer.chat_template, "system is not allowed in the chat template"

    for full_text, prompt_audio_cosy2tokens in zip(full_text_list, prompt_audio_cosy2tokens_list):
        text_tokenization_start_time = time.time()
        prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
        # Create chat template for LLM generation: the assistant turn is
        # pre-filled with the prompt's speech tokens and left open
        # (continue_final_message=True).
        chat = [
            {"role": "user", "content": full_text},
            {"role": "assistant", "content": prompt_audio_cosy2_id_str}
        ]
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True
        )
        input_ids_list.append(input_ids.squeeze(0))
        prompt_text_after_apply_template_list.append(
            f"<|sos|>{full_text}<|task_id|>{prompt_audio_cosy2_id_str}"
        )
        total_text_tokenization_time += time.time() - text_tokenization_start_time

    ids = [item["id"] for item in batch]

    return {
        "input_ids": input_ids_list,
        "ids": ids,
        "prompt_text": prompt_text_list,
        "prompt_audio_list": prompt_audio_list,
        "prompt_text_after_apply_template": prompt_text_after_apply_template_list,
        "audio_processing_time": total_audio_processing_time,
        "speech_tokenization_time": total_speech_tokenization_time,
        "text_tokenization_time": total_text_tokenization_time,
    }
def init_distributed():
    """Set up the NCCL process group from the launcher's environment.

    Reads ``WORLD_SIZE``, ``LOCAL_RANK`` and ``RANK`` from the environment
    (defaulting to 1, 0 and 0), selects the CUDA device numbered
    ``LOCAL_RANK`` and initialises ``torch.distributed`` with the ``nccl``
    backend.

    Returns:
        The tuple ``(world_size, local_rank, rank)``.
    """
    env = os.environ
    world_size = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    rank = int(env.get("RANK", 0))

    message = f"Inference on multiple gpus, this gpu {local_rank}"
    message += f", rank {rank}, world_size {world_size}"
    print(message)

    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
def _load_llm(args, device):
    """Build the LLM for ``args.backend`` and return ``(model, runner)``.

    ``model`` is set for the ``hf`` and ``vllm`` backends, ``runner`` for
    ``trtllm``; the other element is ``None``.  The backend classes are the
    module globals imported by the ``__main__`` guard.

    Raises:
        ValueError: For an unknown backend, or ``trtllm`` without
            ``--engine-dir``.
    """
    if args.backend == "hf":
        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
        model.eval()
        model.to(device)
        return model, None
    if args.backend == "trtllm":
        if args.engine_dir is None:
            raise ValueError("--engine-dir is required when backend is 'trtllm'")
        runner = ModelRunnerCpp.from_dir(
            engine_dir=args.engine_dir,
            rank=tensorrt_llm.mpi_rank(),
            max_output_len=2048,
            enable_context_fmha_fp32_acc=False,
            max_batch_size=args.batch_size,
            max_input_len=512,
            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
            cuda_graph_mode=False,
            gather_generation_logits=False,
        )
        return None, runner
    if args.backend == "vllm":
        return LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4), None
    raise ValueError(f"Unsupported backend: {args.backend}")


def _generate_hf(model, tokenizer, prompt_ids, args, device):
    """Sample with HuggingFace ``generate``; return generated ids per prompt."""
    longest = max(len(ids) for ids in prompt_ids)
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    if pad_id is None:
        # Padded positions are masked out below, so any valid id will do.
        pad_id = 0
    # Left-pad: a decoder-only model continues from the last position of
    # each row, so that position must hold a real prompt token.
    input_ids = torch.full((len(prompt_ids), longest), pad_id, dtype=prompt_ids[0].dtype)
    attention_mask = torch.zeros_like(input_ids)
    for row, ids in enumerate(prompt_ids):
        start = longest - len(ids)
        input_ids[row, start:] = ids
        attention_mask[row, start:] = 1
    outputs = model.generate(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        max_new_tokens=2048,
        do_sample=True,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=1.1,
        top_k=args.top_k,
    )
    torch.cuda.synchronize()
    # Every row's prompt occupies exactly `longest` columns after padding.
    return outputs[:, longest:].tolist()


def _generate_trtllm(runner, prompt_ids, end_id, args):
    """Sample with the TensorRT-LLM runner; return generated ids per prompt."""
    outputs = runner.generate(
        batch_input_ids=list(prompt_ids),
        max_new_tokens=2048,
        end_id=end_id,
        pad_id=end_id,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=1.1,
        num_return_sequences=1,
        streaming=False,
        output_sequence_lengths=True,
        output_generation_logits=False,
        return_dict=True,
        return_all_generated_tokens=False
    )
    torch.cuda.synchronize()
    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
    num_output_sents, num_beams, _ = output_ids.size()
    assert num_beams == 1
    batch_size = len(prompt_ids)
    assert num_output_sents // batch_size == 1
    # The sequence is sliced from the prompt length up to the reported
    # sequence length, which keeps only the newly generated tokens.
    return [
        output_ids[i][0][len(prompt_ids[i]):int(sequence_lengths[i][0])].tolist()
        for i in range(batch_size)
    ]


def _generate_vllm(model, prompt_ids, args):
    """Sample with vLLM; return generated ids per prompt."""
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        repetition_penalty=1.1,
        max_tokens=2048,
    )
    outputs = model.generate(
        prompt_token_ids=[ids.tolist() for ids in prompt_ids],
        sampling_params=sampling_params,
    )
    print(outputs)
    # list(...) so the result is a plain list whatever sequence type
    # token_ids is.
    return [list(output.outputs[0].token_ids) for output in outputs]


def main(args):
    """Run the benchmark: LLM speech-token generation, then token2wav.

    The whole dataset is processed three times.  For each batch the selected
    backend generates speech tokens, the token2wav model turns them into
    waveforms, and each waveform is written to ``<output_dir>/<id>.wav``
    with a sample rate of 24000.  After every pass the timing metrics are
    printed and written as JSON to ``<output_dir>/log.txt`` (overwritten
    each pass).

    Args:
        args: Namespace produced by ``get_args``.

    Raises:
        ValueError: For an unsupported backend or a missing ``--engine-dir``
            with the ``trtllm`` backend.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    # Fixed single-process values; nothing here calls init_distributed().
    local_rank, world_size, rank = 0, 1, 0
    device = torch.device(f"cuda:{local_rank}")

    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
    model, runner = _load_llm(args, device)

    end_id = None
    if args.backend == "trtllm":
        # Resolved once: get_vocab() does not need to be called per batch.
        end_id = (
            tokenizer.convert_tokens_to_ids("<|eos1|>")
            if "<|eos1|>" in tokenizer.get_vocab()
            else tokenizer.eos_token_id
        )
        print(f"end_id: {end_id}, tokenizer.eos_token_id: {tokenizer.eos_token_id} ========================")

    token2wav_model = CosyVoice2_Token2Wav(
        model_dir=args.token2wav_path, enable_trt=True, device_id=local_rank
    )
    if args.prompt_speech_path:
        prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
    else:
        prompt_speech_16k = None

    # Splits whose name contains 'zero' come from yuekai/CV3-Eval and get an
    # s3 speech tokenizer for the collator; all others use
    # yuekai/seed_tts_cosy2 and no tokenizer.
    uses_cv3_eval = 'zero' in args.split_name
    s3_tokenizer = (
        s3tokenizer.load_model(f"{args.token2wav_path}/speech_tokenizer_v2.onnx").to(device)
        if uses_cv3_eval
        else None
    )
    dataset_name = "yuekai/CV3-Eval" if uses_cv3_eval else "yuekai/seed_tts_cosy2"
    dataset = load_dataset(
        dataset_name,
        split=args.split_name,
        trust_remote_code=True,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=None,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
    )

    output_sample_rate = 24000  # rate passed to sf.write and used for durations

    for run_idx in range(3):
        print(f"Running {run_idx} times")
        total_llm_time = 0
        total_token2wav_time = 0
        total_data_load_time = 0
        total_llm_post_processing_time = 0
        total_audio_save_time = 0
        total_audio_processing_time_in_collator = 0
        total_speech_tokenization_time_in_collator = 0
        total_text_tokenization_time_in_collator = 0
        total_audio_samples = 0
        start_time = time.time()
        total_steps = len(dataset)

        if rank == 0:
            progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

        last_batch_end_time = time.time()
        for batch in dataloader:
            # Time between finishing one batch and receiving the next.
            total_data_load_time += time.time() - last_batch_end_time
            total_audio_processing_time_in_collator += batch["audio_processing_time"]
            total_speech_tokenization_time_in_collator += batch["speech_tokenization_time"]
            total_text_tokenization_time_in_collator += batch["text_tokenization_time"]

            with torch.no_grad():
                llm_start_time = time.time()
                if args.backend == "hf":
                    generated = _generate_hf(model, tokenizer, batch["input_ids"], args, device)
                elif args.backend == "trtllm":
                    generated = _generate_trtllm(runner, batch["input_ids"], end_id, args)
                else:
                    # Only "vllm" is left: _load_llm rejected anything else.
                    generated = _generate_vllm(model, batch["input_ids"], args)
                total_llm_time += time.time() - llm_start_time

                items_for_token_2wav = []
                for i, sample_id in enumerate(batch["ids"]):
                    llm_post_processing_start_time = time.time()
                    # Decoding a flat list of ids yields one string per token.
                    speech_tokens_str = tokenizer.batch_decode(generated[i], skip_special_tokens=True)
                    speech_ids = extract_speech_ids(speech_tokens_str)
                    print(i, speech_ids)
                    if not speech_ids:
                        print(f"Warning: No speech tokens generated for sample {sample_id}, skipping")
                        continue

                    # --prompt-text switches every sample to the global
                    # prompt audio loaded from --prompt-speech-path.
                    if args.prompt_text is not None:
                        current_prompt_audio = prompt_speech_16k
                    else:
                        current_prompt_audio = batch["prompt_audio_list"][i]
                    total_llm_post_processing_time += time.time() - llm_post_processing_start_time

                    if current_prompt_audio is not None:
                        items_for_token_2wav.append({
                            "speech_ids": speech_ids,
                            "prompt_audio": current_prompt_audio.squeeze(0),
                            "id": sample_id
                        })
                    else:
                        print(f"Warning: No prompt audio available for sample {sample_id}, skipping")

                for chunk_start in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
                    t2w_batch = items_for_token_2wav[chunk_start:chunk_start + args.token2wav_batch_size]
                    t2w_generated_speech_tokens_list = [item["speech_ids"] for item in t2w_batch]
                    t2w_prompt_audios_list = [item["prompt_audio"] for item in t2w_batch]
                    t2w_prompt_audios_sample_rate = [16000] * len(t2w_batch)
                    t2w_ids = [item["id"] for item in t2w_batch]

                    token2wav_start_time = time.time()
                    generated_wavs = token2wav_model(
                        t2w_generated_speech_tokens_list,
                        t2w_prompt_audios_list,
                        t2w_prompt_audios_sample_rate,
                    )
                    torch.cuda.synchronize()
                    total_token2wav_time += time.time() - token2wav_start_time

                    audio_save_start_time = time.time()
                    for utt, tokens, audio_hat in zip(t2w_ids, t2w_generated_speech_tokens_list, generated_wavs):
                        generated_wave = audio_hat.squeeze().cpu().numpy()
                        total_audio_samples += len(generated_wave)
                        sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, output_sample_rate)
                        print(f"Generated audio for sample {utt} with {len(tokens)} tokens")
                    total_audio_save_time += time.time() - audio_save_start_time

            if rank == 0:
                progress_bar.update(world_size * len(batch["ids"]))

            last_batch_end_time = time.time()

        if rank == 0:
            progress_bar.close()
        end_time = time.time()
        total_audio_duration_seconds = total_audio_samples / output_sample_rate

        log_file_path = os.path.join(args.output_dir, "log.txt")
        log_data = {
            "args": vars(args),
            "data_load_time_seconds": total_data_load_time,
            "audio_processing_time_in_collator_seconds": total_audio_processing_time_in_collator,
            "speech_tokenization_time_in_collator_seconds": total_speech_tokenization_time_in_collator,
            "text_tokenization_time_in_collator_seconds": total_text_tokenization_time_in_collator,
            "llm_time_seconds": total_llm_time,
            "llm_post_processing_time_seconds": total_llm_post_processing_time,
            "token2wav_time_seconds": total_token2wav_time,
            "audio_save_time_seconds": total_audio_save_time,
            "total_audio_duration_seconds": total_audio_duration_seconds,
            "pipeline_time_seconds": end_time - start_time,
        }
        with open(log_file_path, 'w') as f:
            print(log_data)
            f.write(json.dumps(log_data, indent=4))
        print(f"Metrics logged to {log_file_path}")
if __name__ == "__main__":
    args = get_args()
    # Import only the packages of the selected backend; the imported names
    # become module globals that main() looks up when it builds the LLM.
    if args.backend == "hf":
        from transformers import AutoModelForCausalLM
    elif args.backend == "trtllm":
        import tensorrt_llm
        from tensorrt_llm.runtime import ModelRunnerCpp
    elif args.backend == "vllm":
        from vllm import LLM, SamplingParams
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
    main(args)
|