""" Example Usage CUDA_VISIBLE_DEVICES=0 \ python3 infer_cosyvoice3_token2wav.py \ --output-dir $output_dir \ --llm-model-name-or-path $huggingface_model_local_dir \ --token2wav-path $token2wav_model_dir \ --backend $backend \ --batch-size $batch_size --token2wav-batch-size $token2wav_batch_size \ --engine-dir $trt_engines_dir \ --split-name ${dataset} || exit 1 """ import argparse import json import os import time import asyncio import torch import torchaudio import s3tokenizer import soundfile as sf import requests import httpx from transformers import AutoTokenizer from datasets import load_dataset from torch.utils.data import DataLoader from functools import partial from tqdm import tqdm from token2wav_cosyvoice3 import CosyVoice3_Token2Wav try: torch.multiprocessing.set_start_method("spawn") except RuntimeError: pass async def send_request_async(client, url, payload): response = await client.post(url, json=payload, timeout=None) response.raise_for_status() response_json = response.json() return response_json['choices'][0]['message']['content'] async def send_batch_requests_async(api_base, model_name, chats, temperature, top_p, top_k): async with httpx.AsyncClient() as client: tasks = [] for chat in chats: payload = { "model": model_name, "messages": chat, "max_tokens": 2048, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": 1.1, "stop": ["<|eos1|>", "<|eos|>"], "stream": False, } tasks.append(send_request_async(client, api_base, payload)) return await asyncio.gather(*tasks) def extract_speech_ids(speech_tokens_str): """Extract speech IDs from token strings like <|s_23456|>""" speech_ids = [] for token_str in speech_tokens_str: if token_str.startswith('<|s_') and token_str.endswith('|>'): num_str = token_str[4:-2] num = int(num_str) speech_ids.append(num) else: print(f"Unexpected token: {token_str}") return speech_ids def convert_cosy3_tokens_to_speech_id_str(cosy3_tokens): """Convert CosyVoice3 tokens to speech IDs string like <|s_23456|>""" if hasattr(cosy3_tokens, 'cpu'): cosy3_tokens = cosy3_tokens.cpu().numpy().tolist() speech_id_str = "" for token in cosy3_tokens: speech_id_str += f"<|s_{token}|>" return speech_id_str def get_args(): parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice3") parser.add_argument( "--split-name", type=str, default="wenetspeech4tts", help="huggingface dataset split name", ) parser.add_argument( "--output-dir", required=True, type=str, help="dir to save result", ) parser.add_argument( "--batch-size", default=1, type=int, help="batch size (per-device) for LLM inference", ) parser.add_argument( "--token2wav-batch-size", default=1, type=int, help="batch size (per-device) for token2wav inference", ) parser.add_argument( "--num-workers", type=int, default=0, help="workers for dataloader", ) parser.add_argument( "--prefetch", type=int, default=None, help="prefetch for dataloader", ) parser.add_argument( "--llm-model-name-or-path", required=True, type=str, help="CosyVoice3 HF LLM path (e.g. ./hf_cosyvoice3_llm)", ) parser.add_argument( "--token2wav-path", required=True, type=str, help="CosyVoice3 model path (e.g. 
def get_args():
    parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice3")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        help="huggingface dataset split name",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        type=str,
        help="dir to save result",
    )
    parser.add_argument(
        "--batch-size",
        default=1,
        type=int,
        help="batch size (per-device) for LLM inference",
    )
    parser.add_argument(
        "--token2wav-batch-size",
        default=1,
        type=int,
        help="batch size (per-device) for token2wav inference",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        help="workers for dataloader",
    )
    parser.add_argument(
        "--prefetch",
        type=int,
        default=None,
        help="prefetch for dataloader",
    )
    parser.add_argument(
        "--llm-model-name-or-path",
        required=True,
        type=str,
        help="CosyVoice3 HF LLM path (e.g. ./hf_cosyvoice3_llm)",
    )
    parser.add_argument(
        "--token2wav-path",
        required=True,
        type=str,
        help="CosyVoice3 model path (e.g. /workspace_yuekai/HF/Fun-CosyVoice3-0.5B-2512)",
    )
    parser.add_argument(
        "--enable-trt",
        action="store_true",
        help="Enable TensorRT for flow decoder estimator",
    )
    parser.add_argument(
        "--streaming",
        action="store_true",
        help="Enable streaming for flow decoder estimator",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="top p for sampling",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="temperature for sampling",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=15,
        help="top k for sampling",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="hf",
        choices=["hf", "trtllm", "vllm", "trtllm-serve"],
        help="Backend to use for LLM inference",
    )
    parser.add_argument(
        "--engine-dir",
        type=str,
        default=None,
        help="TensorRT-LLM engine directory (required when backend is 'trtllm')",
    )
    parser.add_argument(
        "--kv-cache-free-gpu-memory-fraction",
        type=float,
        default=0.6,
        help="Fraction of GPU memory to free for KV cache (TensorRT-LLM only)",
    )
    parser.add_argument(
        "--openai-api-base",
        type=str,
        default="http://localhost:8000/v1/chat/completions",
        help="OpenAI API base URL (for trtllm-serve backend)",
    )
    parser.add_argument(
        "--openai-model-name",
        type=str,
        default="trt_engines_bfloat16",
        help="Model name to use with OpenAI API (for trtllm-serve backend)",
    )
    parser.add_argument(
        "--epoch",
        type=int,
        default=1,
        help="Number of epochs to run",
    )
    return parser.parse_args()


def data_collator(batch, tokenizer, s3_tokenizer):
    """Data collator: extracts cosy3 tokens from prompt_audio using v3 s3 tokenizer."""
    device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
    target_sample_rate = 16000
    input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
    mels, prompt_audio_cosy3tokens_list, full_text_list = [], [], []
    chat_list = []
    for item in batch:
        prompt_text, target_text = item["prompt_text"], item["target_text"]
        prompt_text_list.append(prompt_text)
        full_text = 'You are a helpful assistant.<|endofprompt|>' + prompt_text + target_text
        full_text_list.append(full_text)

        # Get prompt audio (convert to 16kHz for s3 tokenizer)
        ref_audio = torch.from_numpy(item["prompt_audio"]["array"]).float().unsqueeze(0)
        ref_sr = item["prompt_audio"]["sampling_rate"]
        if ref_sr != target_sample_rate:
            ref_audio = torchaudio.transforms.Resample(ref_sr, target_sample_rate)(ref_audio)
        prompt_audio_list.append(ref_audio)

        # Extract cosy3 tokens from prompt_audio using v3 s3 tokenizer
        mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))

    # Batch tokenization with v3 tokenizer
    if len(mels) > 0:
        mels_padded, mels_lens = s3tokenizer.padding(mels)
        codes, codes_lens = s3_tokenizer.quantize(mels_padded.to(device), mels_lens.to(device))
        for i in range(len(codes)):
            prompt_audio_cosy3tokens_list.append(codes[i, :codes_lens[i].item()])

    # Build LLM inputs
    for i, prompt_audio_cosy3tokens in enumerate(prompt_audio_cosy3tokens_list):
        prompt_audio_cosy3_id_str = convert_cosy3_tokens_to_speech_id_str(prompt_audio_cosy3tokens)
        chat = [
            {"role": "user", "content": full_text_list[i]},
            {"role": "assistant", "content": prompt_audio_cosy3_id_str}
        ]
        chat_list.append(chat)
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True,
        )
        input_ids_list.append(input_ids.squeeze(0))

    ids = [item["id"] for item in batch]

    return {
        "input_ids": input_ids_list,
        "ids": ids,
        "prompt_text": prompt_text_list,
        "prompt_audio_list": prompt_audio_list,
        "chat_list": chat_list,
    }
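
# Note on the prompt layout produced by data_collator (descriptive comment only, no
# behavior change): the user turn carries "You are a helpful assistant.<|endofprompt|>"
# plus prompt_text and target_text, while the assistant turn is pre-filled with the
# prompt audio's <|s_xxx|> tokens. apply_chat_template(..., continue_final_message=True)
# leaves that assistant turn open, so the LLM continues it with speech tokens for the
# target text; downstream post-processing slices off the input length to keep only the
# newly generated tokens.
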
def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    assert torch.cuda.is_available()
    local_rank = 0
    device = torch.device(f"cuda:{local_rank}")

    tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)

    # Select and initialize the LLM backend.
    if args.backend == "hf":
        model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
        model.eval()
        model.to(device)
        runner = None
    elif args.backend == "trtllm":
        if args.engine_dir is None:
            raise ValueError("--engine-dir is required when backend is 'trtllm'")
        runtime_rank = tensorrt_llm.mpi_rank()
        model = None
        runner_kwargs = dict(
            engine_dir=args.engine_dir,
            rank=runtime_rank,
            max_output_len=2048,
            enable_context_fmha_fp32_acc=False,
            max_batch_size=args.batch_size,
            max_input_len=512,
            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
            cuda_graph_mode=False,
            gather_generation_logits=False,
        )
        runner = ModelRunnerCpp.from_dir(**runner_kwargs)
    elif args.backend == "vllm":
        model = LLM(model=args.llm_model_name_or_path, gpu_memory_utilization=0.4)
        runner = None
    elif args.backend == "trtllm-serve":
        model = None
        runner = None
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")

    token2wav_model = CosyVoice3_Token2Wav(
        model_dir=args.token2wav_path,
        enable_trt=args.enable_trt,
        device_id=local_rank,
        streaming=args.streaming,
    )

    # Load v3 s3 tokenizer for prompt audio tokenization in data_collator
    s3_tokenizer = s3tokenizer.load_model(
        f"{args.token2wav_path}/speech_tokenizer_v3.onnx"
    ).to(device).eval()

    dataset = load_dataset(
        "yuekai/seed_tts_cosy2",
        split=args.split_name,
        trust_remote_code=True,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
    )

    for epoch in range(args.epoch):
        print(f"Running epoch {epoch}")
        # Per-epoch timing and throughput accounting.
        total_llm_time = 0
        total_token2wav_time = 0
        total_data_load_time = 0
        total_llm_post_processing_time = 0
        total_audio_save_time = 0
        total_audio_samples = 0
        start_time = time.time()
        progress_bar = tqdm(total=len(dataset), desc="Processing", unit="wavs")
        last_batch_end_time = time.time()

        for batch in dataloader:
            data_loaded_time = time.time()
            total_data_load_time += data_loaded_time - last_batch_end_time

            with torch.no_grad():
                llm_start_time = time.time()
                if args.backend == "hf":
                    input_ids_list = batch["input_ids"]
                    if len(input_ids_list) == 1:
                        input_ids = input_ids_list[0].unsqueeze(0)
                        attention_mask = torch.ones_like(input_ids)
                    else:
                        max_len = max([len(ids) for ids in input_ids_list])
                        input_ids_list_new = [
                            torch.cat([ids, torch.full((max_len - len(ids),), tokenizer.pad_token_id)])
                            for ids in input_ids_list
                        ]
                        input_ids = torch.stack(input_ids_list_new)
                        attention_mask = torch.zeros_like(input_ids)
                        for i in range(len(input_ids_list)):
                            attention_mask[i, :len(input_ids_list[i])] = 1
                    outputs = model.generate(
                        input_ids=input_ids.to(device),
                        attention_mask=attention_mask.to(device),
                        max_new_tokens=2048,
                        do_sample=True,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        repetition_penalty=1.1,
                        top_k=args.top_k,
                    )
                    torch.cuda.synchronize()
                elif args.backend == "trtllm":
                    batch_input_ids = list(batch["input_ids"])
                    input_lengths = [x.size(0) for x in batch_input_ids]
                    end_id = (
                        tokenizer.convert_tokens_to_ids("<|eos1|>")
                        if "<|eos1|>" in tokenizer.get_vocab()
                        else tokenizer.eos_token_id
                    )
                    outputs = runner.generate(
                        batch_input_ids=batch_input_ids,
                        max_new_tokens=2048,
                        end_id=end_id,
                        pad_id=end_id,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        repetition_penalty=1.1,
                        num_return_sequences=1,
                        streaming=False,
                        output_sequence_lengths=True,
                        output_generation_logits=False,
                        return_dict=True,
                        return_all_generated_tokens=False,
                    )
                    torch.cuda.synchronize()
                    output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"]
                    num_output_sents, num_beams, _ = output_ids.size()
                    assert num_beams == 1
                    batch_size = len(batch["input_ids"])
                    num_return_sequences = num_output_sents // batch_size
                    assert num_return_sequences == 1
                    outputs = []
                    for i in range(batch_size * num_return_sequences):
                        batch_idx = i // num_return_sequences
                        output_begin = input_lengths[batch_idx]
                        output_end = sequence_lengths[i][0]
                        outputs_i = output_ids[i][0][:output_end].tolist()
                        outputs.append(outputs_i)
                elif args.backend == "vllm":
                    input_ids_list = [ids.tolist() for ids in batch["input_ids"]]
                    sampling_params = SamplingParams(
                        temperature=args.temperature,
                        top_p=args.top_p,
                        top_k=args.top_k,
                        repetition_penalty=1.1,
                        max_tokens=2048,
                    )
                    outputs = model.generate(prompt_token_ids=input_ids_list, sampling_params=sampling_params)
                    for j, output in enumerate(outputs):
                        # token_ids may be a tuple depending on the vLLM version; cast to list before concatenating.
                        outputs[j] = input_ids_list[j] + list(output.outputs[0].token_ids)
                elif args.backend == "trtllm-serve":
                    if args.batch_size > 1:
                        outputs = asyncio.run(send_batch_requests_async(
                            args.openai_api_base,
                            args.openai_model_name,
                            batch["chat_list"],
                            args.temperature,
                            args.top_p,
                            args.top_k,
                        ))
                    else:
                        outputs = []
                        for chat in batch["chat_list"]:
                            payload = {
                                "model": args.openai_model_name,
                                "messages": chat,
                                "max_tokens": 2048,
                                "temperature": args.temperature,
                                "top_p": args.top_p,
                                "top_k": args.top_k,
                                "repetition_penalty": 1.1,
                                "stop": ["<|eos1|>", "<|eos|>"],
                                "stream": False,
                            }
                            response = requests.post(args.openai_api_base, json=payload)
                            response.raise_for_status()
                            response_json = response.json()
                            generated_content = response_json['choices'][0]['message']['content']
                            outputs.append(generated_content)
                llm_end_time = time.time()
                total_llm_time += (llm_end_time - llm_start_time)

                # Decode each generated sequence back into integer speech token IDs.
                items_for_token_2wav = []
                for i in range(len(batch["ids"])):
                    llm_post_processing_start_time = time.time()
                    if args.backend == "trtllm-serve":
                        speech_tokens_str = outputs[i].strip().split('><')
                        if len(speech_tokens_str) > 1:
                            speech_tokens_str = [
                                t if t.startswith('<') else '<' + t for t in speech_tokens_str
                            ]
                            speech_tokens_str = [
                                t if t.endswith('>') else t + '>' for t in speech_tokens_str
                            ]
                        speech_ids = extract_speech_ids(speech_tokens_str)
                    else:
                        input_length = len(batch["input_ids"][i])
                        generated_ids = outputs[i][input_length:]
                        speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                        speech_ids = extract_speech_ids(speech_tokens_str)
                    print(i, speech_ids[:10], "...", f"total={len(speech_ids)}")

                    if len(speech_ids) == 0:
                        print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                        llm_post_processing_end_time = time.time()
                        total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time
                        continue

                    current_prompt_audio = batch["prompt_audio_list"][i]
                    llm_post_processing_end_time = time.time()
                    total_llm_post_processing_time += llm_post_processing_end_time - llm_post_processing_start_time

                    items_for_token_2wav.append({
                        "speech_ids": speech_ids,
                        "prompt_audio": current_prompt_audio.squeeze(0),
                        "id": batch["ids"][i],
                    })

                # Vocode speech tokens to waveforms in token2wav-sized sub-batches.
                for i in range(0, len(items_for_token_2wav), args.token2wav_batch_size):
                    t2w_batch = items_for_token_2wav[i:i + args.token2wav_batch_size]
                    if not t2w_batch:
                        continue
                    t2w_speech_tokens = [item["speech_ids"] for item in t2w_batch]
                    t2w_prompt_audios = [item["prompt_audio"] for item in t2w_batch]
                    t2w_sample_rates = [16000] * len(t2w_batch)

                    token2wav_start_time = time.time()
                    generated_wavs = token2wav_model(
                        t2w_speech_tokens,
                        t2w_prompt_audios,
                        t2w_sample_rates,
                        streaming=args.streaming,
                    )
                    token2wav_end_time = time.time()
                    total_token2wav_time += (token2wav_end_time - token2wav_start_time)

                    audio_save_start_time = time.time()
                    for j, audio_hat in enumerate(generated_wavs):
                        wav = audio_hat.squeeze().cpu().numpy()
                        total_audio_samples += len(wav)
                        sf.write(f"{args.output_dir}/{t2w_batch[j]['id']}.wav", wav, 24000)
                        print(f"Generated audio for sample {t2w_batch[j]['id']} with {len(t2w_speech_tokens[j])} tokens")
                    audio_save_end_time = time.time()
                    total_audio_save_time += audio_save_end_time - audio_save_start_time

            progress_bar.update(len(batch["ids"]))
            last_batch_end_time = time.time()

        progress_bar.close()
        end_time = time.time()
        total_audio_duration_seconds = total_audio_samples / 24000

        log_file_path = os.path.join(args.output_dir, "log.txt")
        with open(log_file_path, 'w') as f:
            log_data = {
                "args": vars(args),
                "data_load_time_seconds": total_data_load_time,
                "llm_time_seconds": total_llm_time,
                "llm_post_processing_time_seconds": total_llm_post_processing_time,
                "token2wav_time_seconds": total_token2wav_time,
                "audio_save_time_seconds": total_audio_save_time,
                "total_audio_duration_seconds": total_audio_duration_seconds,
                "pipeline_time_seconds": end_time - start_time,
            }
            print(log_data)
            f.write(json.dumps(log_data, indent=4))
        print(f"Metrics logged to {log_file_path}")


if __name__ == "__main__":
    args = get_args()
    # Import backend-specific dependencies lazily so unused backends need not be installed.
    if args.backend == "vllm":
        from vllm import LLM, SamplingParams
    elif args.backend == "trtllm":
        import tensorrt_llm
        from tensorrt_llm.runtime import ModelRunnerCpp
    elif args.backend == "hf":
        from transformers import AutoModelForCausalLM
    elif args.backend == "trtllm-serve":
        pass
    else:
        raise ValueError(f"Unsupported backend: {args.backend}")
    main(args)