- """ Example Usage
- CUDA_VISIBLE_DEVICES=0 \
- python3 infer_cosyvoice3_token2wav.py \
- --output-dir $output_dir \
- --llm-model-name-or-path $huggingface_model_local_dir \
- --token2wav-path $token2wav_model_dir \
- --backend $backend \
- --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \
- --engine-dir $trt_engines_dir \
- --split-name ${dataset} || exit 1
- """
import argparse
import json
import os
import time
import asyncio

import torch
import torchaudio
import s3tokenizer
import soundfile as sf
import requests
import httpx
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from functools import partial
from tqdm import tqdm

from token2wav_cosyvoice3 import CosyVoice3_Token2Wav

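# CUDA cannot be (re)initialized inside forked DataLoader worker processes, so use
# the "spawn" start method; ignore the error if a start method was already set.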
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass


async def send_request_async(client, url, payload):
    response = await client.post(url, json=payload, timeout=None)
    response.raise_for_status()
    response_json = response.json()
    return response_json['choices'][0]['message']['content']


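# For the trtllm-serve backend, chat-completion requests in the OpenAI format are
# posted to an already-running server at --openai-api-base; this helper fires all
# requests of a batch concurrently and returns each sample's generated text.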
async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k):
    async with httpx.AsyncClient() as client:
        tasks = []
        for chat in chats:
            payload = {
                "model": model_name,
                "messages": chat,
                "max_tokens": 2048,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": 1.1,
                "stop": ["<|eos1|>", "<|eos|>"],
                "stream": False,
            }
            tasks.append(send_request_async(client, api_base, payload))
        return await asyncio.gather(*tasks)


def extract_speech_ids(speech_tokens_str):
    """Extract speech IDs from token strings like <|s_23456|>."""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


def convert_cosy3_tokens_to_speech_id_str(cosy3_tokens):
    """Convert CosyVoice3 tokens to a speech ID string like <|s_23456|>."""
    if hasattr(cosy3_tokens, 'cpu'):
        cosy3_tokens = cosy3_tokens.cpu().numpy().tolist()
    speech_id_str = ""
    for token in cosy3_tokens:
        speech_id_str += f"<|s_{token}|>"
    return speech_id_str


def get_args():
    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice3")
    parser.add_argument(
        "--split-name", type=str, default="wenetspeech4tts",
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--output-dir", required=True, type=str, help="directory to save results",
    )
    parser.add_argument(
        "--batch-size", default=1, type=int,
        help="batch size (per device) for LLM inference",
    )
    parser.add_argument(
        "--token2wav-batch-size", default=1, type=int,
        help="batch size (per device) for token2wav inference",
    )
    parser.add_argument(
        "--num-workers", type=int, default=0, help="number of dataloader workers",
    )
    parser.add_argument(
        "--prefetch", type=int, default=None, help="prefetch factor for the dataloader",
    )
    parser.add_argument(
        "--llm-model-name-or-path", required=True, type=str,
        help="CosyVoice3 HF LLM path (e.g. ./hf_cosyvoice3_llm)",
    )
    parser.add_argument(
        "--token2wav-path", required=True, type=str,
        help="CosyVoice3 model path (e.g. /workspace_yuekai/HF/Fun-CosyVoice3-0.5B-2512)",
    )
    parser.add_argument(
        "--enable-trt", action="store_true",
        help="Enable TensorRT for the flow decoder estimator",
    )
    parser.add_argument(
        "--streaming", action="store_true",
        help="Enable streaming mode for token2wav inference",
    )
    parser.add_argument(
        "--top-p", type=float, default=0.95, help="top-p for sampling",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.8, help="temperature for sampling",
    )
    parser.add_argument(
        "--top-k", type=int, default=15, help="top-k for sampling",
    )
    parser.add_argument(
        "--backend", type=str, default="hf",
        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
        help="Backend to use for LLM inference",
    )
    parser.add_argument(
        "--engine-dir", type=str, default=None,
        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
    )
    parser.add_argument(
        "--kv-cache-free-gpu-memory-fraction", type=float, default=0.6,
        help="Fraction of free GPU memory to use for the KV cache (TensorRT-LLM only)",
    )
    parser.add_argument(
        "--openai-api-base", type=str,
        default="http://localhost:8000/v1/chat/completions",
        help="OpenAI API base URL (for the trtllm-serve backend)",
    )
    parser.add_argument(
        "--openai-model-name", type=str, default="trt_engines_bfloat16",
        help="Model name to use with the OpenAI API (for the trtllm-serve backend)",
    )
    parser.add_argument(
        "--epoch", type=int, default=1, help="number of epochs to run",
    )
    return parser.parse_args()


def data_collator(batch, tokenizer, s3_tokenizer):
    """Data collator: extracts cosy3 tokens from prompt_audio using the v3 s3 tokenizer."""
    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
    target_sample_rate = 16000
    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
    mels, prompt_audio_cosy3tokens_list, full_text_list = [], [], []
    chat_list = []
    for item in batch:
        prompt_text, target_text = item["prompt_text"], item["target_text"]
        prompt_text_list.append(prompt_text)
        full_text = 'You are a helpful assistant.<|endofprompt|>' + prompt_text + target_text
        full_text_list.append(full_text)
        # Get prompt audio (convert to 16 kHz for the s3 tokenizer)
        ref_audio = torch.from_numpy(item["prompt_audio"]["array"]).float().unsqueeze(0)
        ref_sr = item["prompt_audio"]["sampling_rate"]
        if ref_sr != target_sample_rate:
            ref_audio = torchaudio.transforms.Resample(ref_sr, target_sample_rate)(ref_audio)
        prompt_audio_list.append(ref_audio)
        # Extract cosy3 tokens from prompt_audio using the v3 s3 tokenizer
        mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
    # Batch tokenization with the v3 tokenizer
    if len(mels) > 0:
        mels_padded, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = s3_tokenizer.quantize(mels_padded.to(device), mels_lens.to(device))
        for i in range(len(codes)):
            prompt_audio_cosy3tokens_list.append(codes[i, :codes_lens[i].item()])
    # Build LLM inputs
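    # Each chat places the system prompt + prompt text + target text in the user turn
    # and pre-fills the assistant turn with the prompt audio's speech tokens;
    # continue_final_message=True lets the LLM continue that assistant turn and emit
    # the speech tokens for the target text.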
    for i, prompt_audio_cosy3tokens in enumerate(prompt_audio_cosy3tokens_list):
        prompt_audio_cosy3_id_str = convert_cosy3_tokens_to_speech_id_str(
            prompt_audio_cosy3tokens)
        chat = [
            {"role": "user", "content": full_text_list[i]},
            {"role": "assistant", "content": prompt_audio_cosy3_id_str},
        ]
        chat_list.append(chat)
        input_ids = tokenizer.apply_chat_template(
            chat, tokenize=True, return_tensors='pt', continue_final_message=True)
        input_ids_list.append(input_ids.squeeze(0))
    ids = [item["id"] for item in batch]
    return {
        "input_ids": input_ids_list,
        "ids": ids,
        "prompt_text": prompt_text_list,
        "prompt_audio_list": prompt_audio_list,
        "chat_list": chat_list,
    }


def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    assert torch.cuda.is_available()
    local_rank = 0
    device = torch.device(f"cuda:{local_rank}")

    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)

    if args.backend == "hf":
        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
        model.eval()
        model.to(device)
        runner = None
    elif args.backend == "trtllm":
        if args.engine_dir is None:
            raise ValueError("--engine-dir is required when backend is 'trtllm'")
        runtime_rank = tensorrt_llm.mpi_rank()
        model = None
        runner_kwargs = dict(
            engine_dir=args.engine_dir,
            rank=runtime_rank,
            max_output_len=2048,
            enable_context_fmha_fp32_acc=False,
            max_batch_size=args.batch_size,
            max_input_len=512,
            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
            cuda_graph_mode=False,
            gather_generation_logits=False,
        )
        runner = ModelRunnerCpp.from_dir(**runner_kwargs)
    elif args.backend == "vllm":
        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
        runner = None
    elif args.backend == "trtllm-serve":
        model = None
        runner = None
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")

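    # token2wav turns discrete speech tokens back into waveforms; its flow decoder
    # estimator can optionally run with a TensorRT engine (--enable-trt).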
    token2wav_model = CosyVoice3_Token2Wav(
        model_dir=args.token2wav_path,
        enable_trt=args.enable_trt,
        device_id=local_rank,
        streaming=args.streaming,
    )

    # Load the v3 s3 tokenizer for prompt audio tokenization in data_collator
    s3_tokenizer = s3tokenizer.load_model(
        f"{args.token2wav_path}/speech_tokenizer_v3.onnx"
    ).to(device).eval()

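    # The test set provides "id", "prompt_text", "prompt_audio" and "target_text"
    # fields, which data_collator consumes.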
    dataset = load_dataset(
        "yuekai/seed_tts_cosy2",
        split=args.split_name,
        trust_remote_code=True,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
    )

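    # Per-stage wall-clock timers are accumulated over each epoch and written to
    # <output-dir>/log.txt at the end of the epoch.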
    for epoch in range(args.epoch):
        print(f"Running epoch {epoch}")
        total_llm_time = 0
        total_token2wav_time = 0
        total_data_load_time = 0
        total_llm_post_processing_time = 0
        total_audio_save_time = 0
        total_audio_samples = 0
        start_time = time.time()
        progress_bar = tqdm(total=len(dataset), desc="Processing", unit="wavs")
        last_batch_end_time = time.time()

        for batch in dataloader:
            data_loaded_time = time.time()
            total_data_load_time += data_loaded_time - last_batch_end_time

            with torch.no_grad():
                llm_start_time = time.time()
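                # Each backend below fills `outputs`: the local backends (hf, trtllm,
                # vllm) produce full token-id sequences that still include the prompt,
                # while trtllm-serve returns the generated text directly.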
                if args.backend == "hf":
                    input_ids_list = batch["input_ids"]
                    if len(input_ids_list) == 1:
                        input_ids = input_ids_list[0].unsqueeze(0)
                        attention_mask = torch.ones_like(input_ids)
                    else:
                        max_len = max([len(ids) for ids in input_ids_list])
                        input_ids_list_new = [
                            torch.cat([ids, torch.full((max_len - len(ids),), tokenizer.pad_token_id)])
                            for ids in input_ids_list
                        ]
                        input_ids = torch.stack(input_ids_list_new)
                        attention_mask = torch.zeros_like(input_ids)
                        for i in range(len(input_ids_list)):
                            attention_mask[i, :len(input_ids_list[i])] = 1
                    outputs = model.generate(
                        input_ids=input_ids.to(device),
                        attention_mask=attention_mask.to(device),
                        max_new_tokens=2048,
                        do_sample=True,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        repetition_penalty=1.1,
                        top_k=args.top_k,
                    )
                    torch.cuda.synchronize()
                elif args.backend == "trtllm":
                    batch_input_ids = list(batch["input_ids"])
                    input_lengths = [x.size(0) for x in batch_input_ids]
                    end_id = (
                        tokenizer.convert_tokens_to_ids("<|eos1|>")
                        if "<|eos1|>" in tokenizer.get_vocab()
                        else tokenizer.eos_token_id
                    )
                    outputs = runner.generate(
                        batch_input_ids=batch_input_ids,
                        max_new_tokens=2048,
                        end_id=end_id,
                        pad_id=end_id,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        repetition_penalty=1.1,
                        num_return_sequences=1,
                        streaming=False,
                        output_sequence_lengths=True,
                        output_generation_logits=False,
                        return_dict=True,
                        return_all_generated_tokens=False,
                    )
                    torch.cuda.synchronize()
                    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
                    num_output_sents, num_beams, _ = output_ids.size()
                    assert num_beams == 1
                    batch_size = len(batch["input_ids"])
                    num_return_sequences = num_output_sents // batch_size
                    assert num_return_sequences == 1
                    outputs = []
                    for i in range(batch_size * num_return_sequences):
                        batch_idx = i // num_return_sequences
                        output_begin = input_lengths[batch_idx]
                        output_end = sequence_lengths[i][0]
                        outputs_i = output_ids[i][0][:output_end].tolist()
                        outputs.append(outputs_i)
                elif args.backend == "vllm":
                    input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
                    sampling_params = SamplingParams(
                        temperature=args.temperature,
                        top_p=args.top_p,
                        top_k=args.top_k,
                        repetition_penalty=1.1,
                        max_tokens=2048,
                    )
                    outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
                    for j, output in enumerate(outputs):
                        # Prepend the prompt ids so downstream slicing by input length works;
                        # token_ids may not be a plain list, so cast before concatenating.
                        outputs[j] = input_ids_list[j] + list(output.outputs[0].token_ids)
                elif args.backend == "trtllm-serve":
                    if args.batch_size > 1:
                        outputs = asyncio.run(send_batch_requests_async(
                            args.openai_api_base,
                            args.openai_model_name,
                            batch["chat_list"],
                            args.temperature,
                            args.top_p,
                            args.top_k,
                        ))
                    else:
                        outputs = []
                        for chat in batch["chat_list"]:
                            payload = {
                                "model": args.openai_model_name,
                                "messages": chat,
                                "max_tokens": 2048,
                                "temperature": args.temperature,
                                "top_p": args.top_p,
                                "top_k": args.top_k,
                                "repetition_penalty": 1.1,
                                "stop": ["<|eos1|>", "<|eos|>"],
                                "stream": False,
                            }
                            response = requests.post(args.openai_api_base, json=payload)
                            response.raise_for_status()
                            response_json = response.json()
                            generated_content = response_json['choices'][0]['message']['content']
                            outputs.append(generated_content)
                llm_end_time = time.time()
                total_llm_time += (llm_end_time - llm_start_time)

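                # Post-processing: recover the generated speech-token ids for each
                # sample and pair them with its prompt audio for token2wav synthesis.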
                items_for_token_2wav = []
                for i in range(len(batch["ids"])):
                    llm_post_processing_start_time = time.time()
                    if args.backend == "trtllm-serve":
                        speech_tokens_str = outputs[i].strip().split('><')
                        if len(speech_tokens_str) > 1:
                            speech_tokens_str = [
                                t if t.startswith('<') else '<' + t for t in speech_tokens_str
                            ]
                            speech_tokens_str = [
                                t if t.endswith('>') else t + '>' for t in speech_tokens_str
                            ]
                        speech_ids = extract_speech_ids(speech_tokens_str)
                    else:
                        input_length = len(batch["input_ids"][i])
                        generated_ids = outputs[i][input_length:]
                        speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                        speech_ids = extract_speech_ids(speech_tokens_str)
                    print(i, speech_ids[:10], "...", f"total={len(speech_ids)}")
                    if len(speech_ids) == 0:
                        print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                        llm_post_processing_end_time = time.time()
                        total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
                        continue
                    current_prompt_audio = batch["prompt_audio_list"][i]
                    llm_post_processing_end_time = time.time()
                    total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
                    items_for_token_2wav.append({
                        "speech_ids": speech_ids,
                        "prompt_audio": current_prompt_audio.squeeze(0),
                        "id": batch["ids"][i],
                    })

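                # Regroup the surviving utterances into token2wav batches of size
                # --token2wav-batch-size, independent of the LLM batch size.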
                for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
                    t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
                    if not t2w_batch:
                        continue
                    t2w_speech_tokens = [item["speech_ids"] for item in t2w_batch]
                    t2w_prompt_audios = [item["prompt_audio"] for item in t2w_batch]
                    t2w_sample_rates = [16000] * len(t2w_batch)

                    token2wav_start_time = time.time()
                    generated_wavs = token2wav_model(
                        t2w_speech_tokens, t2w_prompt_audios, t2w_sample_rates,
                        streaming=args.streaming,
                    )
                    token2wav_end_time = time.time()
                    total_token2wav_time += (token2wav_end_time - token2wav_start_time)

                    audio_save_start_time = time.time()
                    for j, audio_hat in enumerate(generated_wavs):
                        wav = audio_hat.squeeze().cpu().numpy()
                        total_audio_samples += len(wav)
                        sf.write(f"{args.output_dir}/{t2w_batch[j]['id']}.wav", wav, 24000)
                        print(f"Generated audio for sample {t2w_batch[j]['id']} with {len(t2w_speech_tokens[j])} tokens")
                    audio_save_end_time = time.time()
                    total_audio_save_time += audio_save_end_time - audio_save_start_time

            progress_bar.update(len(batch["ids"]))
            last_batch_end_time = time.time()

        progress_bar.close()
        end_time = time.time()
        total_audio_duration_seconds = total_audio_samples / 24000

        log_file_path = os.path.join(args.output_dir, "log.txt")
        with open(log_file_path, 'w') as f:
            log_data = {
                "args": vars(args),
                "data_load_time_seconds": total_data_load_time,
                "llm_time_seconds": total_llm_time,
                "llm_post_processing_time_seconds": total_llm_post_processing_time,
                "token2wav_time_seconds": total_token2wav_time,
                "audio_save_time_seconds": total_audio_save_time,
                "total_audio_duration_seconds": total_audio_duration_seconds,
                "pipeline_time_seconds": end_time - start_time,
            }
            print(log_data)
            f.write(json.dumps(log_data, indent=4))
        print(f"Metrics logged to {log_file_path}")


if __name__ == "__main__":
    args = get_args()
    if args.backend == "vllm":
        from vllm import LLM, SamplingParams
    elif args.backend == "trtllm":
        import tensorrt_llm
        from tensorrt_llm.runtime import ModelRunnerCpp
    elif args.backend == "hf":
        from transformers import AutoModelForCausalLM
    elif args.backend == "trtllm-serve":
        pass
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
    main(args)