""" Example Usage CUDA_VISIBLE_DEVICES=0 \ python3 token2wav_cosyvoice3.py --enable-trt || exit 1 """ import torch import torchaudio import torchaudio.compliance.kaldi as kaldi import onnxruntime import s3tokenizer import os import logging import argparse import queue import time import numpy as np from functools import partial from hyperpyyaml import load_hyperpyyaml from matcha.utils.audio import mel_spectrogram as matcha_mel_spectrogram from torch.utils.data import DataLoader from datasets import load_dataset logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # CosyVoice3 mel params from cosyvoice3.yaml (fmax=None, NOT 8000) mel_spectrogram = partial(matcha_mel_spectrogram, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, fmin=0, fmax=None, center=False) def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16, autocast_mode=False): import tensorrt as trt logging.info("Converting onnx to trt...") if autocast_mode: network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) else: network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) logger = trt.Logger(trt.Logger.INFO) builder = trt.Builder(logger) network = builder.create_network(network_flags) parser = trt.OnnxParser(network, logger) config = builder.create_builder_config() config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB if not autocast_mode: if fp16: config.set_flag(trt.BuilderFlag.FP16) profile = builder.create_optimization_profile() # load onnx model with open(onnx_model, "rb") as f: if not parser.parse(f.read()): for error in range(parser.num_errors): print(parser.get_error(error)) raise ValueError('failed to parse {}'.format(onnx_model)) # set input shapes for i in range(len(trt_kwargs['input_names'])): profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT # set input and output data type for i in range(network.num_inputs): input_tensor = network.get_input(i) input_tensor.dtype = tensor_dtype for i in range(network.num_outputs): output_tensor = network.get_output(i) output_tensor.dtype = tensor_dtype config.add_optimization_profile(profile) engine_bytes = builder.build_serialized_network(network, config) # save trt engine with open(trt_model, "wb") as f: f.write(engine_bytes) logging.info("Succesfully convert onnx to trt...") class TrtContextWrapper: def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) self.trt_engine = trt_engine self.device = device for _ in range(trt_concurrent): trt_context = trt_engine.create_execution_context() trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device))) assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) self.trt_context_pool.put([trt_context, trt_stream]) assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' def acquire_estimator(self): return self.trt_context_pool.get(), self.trt_engine def release_estimator(self, context, stream): self.trt_context_pool.put([context, stream]) class CosyVoice3_Token2Wav(torch.nn.Module): def __init__(self, model_dir, enable_trt=False, device_id=0, autocast_mode=True, streaming=False): super().__init__() self.device_id = 
device_id self.device = f"cuda:{device_id}" self.autocast_mode = autocast_mode self.streaming = streaming # Load flow and hift from cosyvoice3.yaml with open(f"{model_dir}/cosyvoice3.yaml", "r") as f: configs = load_hyperpyyaml(f, overrides={ 'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN') }) self.flow = configs['flow'] self.flow.load_state_dict( torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True ) self.flow.to(self.device).eval() self.hift = configs['hift'] hift_state_dict = { k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items() } self.hift.load_state_dict(hift_state_dict, strict=True) self.hift.to(self.device).eval() # Speaker embedding model (campplus) option = onnxruntime.SessionOptions() option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL option.intra_op_num_threads = 1 self.spk_model = onnxruntime.InferenceSession( f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"] ) # Audio tokenizer v3 self.audio_tokenizer = s3tokenizer.load_model( f"{model_dir}/speech_tokenizer_v3.onnx" ).to(self.device).eval() self.fp16 = enable_trt if enable_trt: self.flow.half() self.load_trt(model_dir) self.load_spk_trt(model_dir) def load_trt(self, model_dir, trt_concurrent=1): streaming_prefix = 'streaming.' if self.streaming else '' if self.autocast_mode: onnx_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}autocast_fp16.onnx' trt_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}autocast_fp16.{self.device_id}.plan' else: onnx_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}fp32.onnx' trt_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}fp32.{self.device_id}.plan' if not os.path.exists(trt_path) or os.path.getsize(trt_path) == 0: trt_kwargs = self.get_trt_kwargs() convert_onnx_to_trt(trt_path, trt_kwargs, onnx_path, fp16=True, autocast_mode=self.autocast_mode) del self.flow.decoder.estimator import tensorrt as trt with open(trt_path, 'rb') as f: estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read()) assert estimator_engine is not None, 'failed to load trt {}'.format(trt_path) self.flow.decoder.estimator = TrtContextWrapper( estimator_engine, trt_concurrent=trt_concurrent, device=self.device ) def get_trt_kwargs(self): # CosyVoice3 DiT estimator has 6 inputs: x, mask, mu, t, spks, cond # Only inputs with dynamic dims need optimization profiles. # t=[2(fixed)] and spks=[2(fixed),80(fixed)] are fully fixed, TRT infers from ONNX. 
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape,
                'input_names': input_names}

    def load_spk_trt(self, model_dir, trt_concurrent=1, fp16=False):
        spk_trt_path = f'{model_dir}/campplus.{self.device_id}.fp32.plan'
        spk_onnx_path = f'{model_dir}/campplus.onnx'
        if not os.path.exists(spk_trt_path) or os.path.getsize(spk_trt_path) == 0:
            trt_kwargs = self.get_spk_trt_kwargs()
            convert_onnx_to_trt(spk_trt_path, trt_kwargs, spk_onnx_path, fp16)
        import tensorrt as trt
        with open(spk_trt_path, 'rb') as f:
            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert spk_engine is not None, 'failed to load trt {}'.format(spk_trt_path)
        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_spk_trt_kwargs(self):
        min_shape = [(1, 4, 80)]
        opt_shape = [(1, 500, 80)]
        max_shape = [(1, 3000, 80)]
        input_names = ["input"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape,
                'input_names': input_names}

    def forward_spk_embedding(self, spk_feat):
        if isinstance(self.spk_model, onnxruntime.InferenceSession):
            return self.spk_model.run(
                None,
                {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
            )[0].flatten().tolist()
        else:
            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
            with torch.cuda.device(self.device_id):
                torch.cuda.current_stream().synchronize()
                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
                batch_size = spk_feat.size(0)
                with stream:
                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
                    data_ptrs = [spk_feat.contiguous().data_ptr(),
                                 output_tensor.contiguous().data_ptr()]
                    for i, j in enumerate(data_ptrs):
                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
                    torch.cuda.current_stream().synchronize()
            self.spk_model.release_estimator(spk_model, stream)
            return output_tensor.cpu().numpy().flatten().tolist()

    def prompt_audio_tokenization(self, prompt_audios_list):
        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            log_mel = s3tokenizer.log_mel_spectrogram(audio)
            prompt_speech_mels_list.append(log_mel)
        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
        )
        for i in range(len(prompt_speech_tokens)):
            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
            prompt_speech_tokens_list.append(speech_tokens_i)
        return prompt_speech_tokens_list

    def get_spk_emb(self, prompt_audios_list):
        spk_emb_for_flow = []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
            spk_emb = self.forward_spk_embedding(spk_feat)
            spk_emb_for_flow.append(spk_emb)
        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
        return spk_emb_for_flow

    def get_prompt_mels(self, prompt_audios_list, prompt_audios_sample_rate):
        prompt_mels_for_flow = []
        prompt_mels_lens_for_flow = []
        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
            assert len(audio.shape) == 1
            audio = audio.unsqueeze(0)
            if sample_rate != 24000:
                audio = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=24000)(audio)
            # CosyVoice3: fmax=None (Nyquist), matching cosyvoice3.yaml
            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, 80]
            prompt_mels_for_flow.append(mel)
            prompt_mels_lens_for_flow.append(mel.shape[0])
        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
            prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', 80]
        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
        return prompt_mels_for_flow, prompt_mels_lens_for_flow

    def forward_flow(self, prompt_speech_tokens_list, generated_speech_tokens_list,
                     prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow):
        batch_size = len(generated_speech_tokens_list)
        generated_mels_list = []
        # CausalMaskedDiffWithDiT.inference asserts batch_size==1, so iterate per-sample
        for i in range(batch_size):
            token = torch.tensor([generated_speech_tokens_list[i]]).to(self.device)
            token_len = torch.tensor([len(generated_speech_tokens_list[i])]).to(self.device)
            prompt_token = torch.tensor([prompt_speech_tokens_list[i]]).to(self.device)
            prompt_token_len = torch.tensor([len(prompt_speech_tokens_list[i])]).to(self.device)
            prompt_feat = prompt_mels_for_flow[i:i+1, :prompt_mels_lens_for_flow[i]].to(self.device)
            prompt_feat_len = prompt_mels_lens_for_flow[i:i+1].to(self.device)
            embedding = spk_emb_for_flow[i:i+1].to(self.device)
            # CausalMaskedDiffWithDiT.inference returns mel already without prompt portion
            with torch.cuda.amp.autocast(self.fp16):
                mel, _ = self.flow.inference(
                    token=token, token_len=token_len,
                    prompt_token=prompt_token, prompt_token_len=prompt_token_len,
                    prompt_feat=prompt_feat, prompt_feat_len=prompt_feat_len,
                    embedding=embedding, streaming=False, finalize=True
                )
            generated_mels_list.append(mel)
        return generated_mels_list

    def forward_hift(self, generated_mels_list):
        generated_wavs = []
        for mel in generated_mels_list:
            # CausalHiFTGenerator.inference with finalize=True
            wav, _ = self.hift.inference(speech_feat=mel, finalize=True)
            generated_wavs.append(wav)
        return generated_wavs

    def forward_stream(self, generated_speech_tokens, prompt_speech_tokens, prompt_feat, embedding,
                       token_hop_len=25, stream_scale_factor=2, token_max_hop_len=100):
        """Streaming token2wav for a single sample: process tokens in chunks."""
        prompt_token = torch.tensor([prompt_speech_tokens]).to(self.device)
        prompt_token_len = torch.tensor([len(prompt_speech_tokens)]).to(self.device)
        prompt_feat = prompt_feat.to(self.device)
        prompt_feat_len = torch.tensor([prompt_feat.shape[1]]).to(self.device)
        embedding = embedding.to(self.device)

        pre_lookahead_len = self.flow.pre_lookahead_len
        token_mel_ratio = self.flow.token_mel_ratio

        # Align first chunk with hop_len boundary
        prompt_token_pad = int(
            np.ceil(prompt_token.shape[1] / token_hop_len) * token_hop_len - prompt_token.shape[1]
        )

        total_tokens = len(generated_speech_tokens)
        token_offset = 0
        current_hop = token_hop_len
        hift_cache_mel = None
        speech_offset = 0
        audio_chunks = []

        while token_offset < total_tokens:
            this_hop = current_hop + prompt_token_pad if token_offset == 0 else current_hop
            remaining = total_tokens - token_offset
            if remaining >= this_hop + pre_lookahead_len:
                end_idx = token_offset + this_hop + pre_lookahead_len
                this_token = torch.tensor([generated_speech_tokens[:end_idx]]).to(self.device)
                finalize = False
            else:
                this_token = torch.tensor([generated_speech_tokens]).to(self.device)
                finalize = True

            with torch.cuda.amp.autocast(self.fp16):
                mel, _ = self.flow.inference(
                    token=this_token,
                    token_len=torch.tensor([this_token.shape[1]]).to(self.device),
                    prompt_token=prompt_token, prompt_token_len=prompt_token_len,
                    prompt_feat=prompt_feat, prompt_feat_len=prompt_feat_len,
                    embedding=embedding, streaming=True, finalize=finalize,
                )
            mel = mel[:, :, token_offset * token_mel_ratio:]

            if hift_cache_mel is not None:
                mel = torch.concat([hift_cache_mel, mel], dim=2)
            hift_cache_mel = mel

            tts_speech, _ = self.hift.inference(speech_feat=mel, finalize=finalize)
            tts_speech = tts_speech[:, speech_offset:]
            speech_offset += tts_speech.shape[1]
            logger.info(f"[stream] token_offset={token_offset}, this_hop={this_hop}, "
                        f"mel_shape={mel.shape}, speech_len={tts_speech.shape[1]}, finalize={finalize}")
            audio_chunks.append(tts_speech)

            token_offset += this_hop
            if not finalize:
                current_hop = min(token_max_hop_len, current_hop * stream_scale_factor)
            else:
                break

        return torch.cat(audio_chunks, dim=1)

    @torch.inference_mode()
    def forward(self, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate, streaming=False):
        assert all(sr == 16000 for sr in prompt_audios_sample_rate)
        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(
            prompt_audios_list, prompt_audios_sample_rate)
        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)

        # Align prompt_speech_feat and prompt_speech_token to exact 2:1 ratio
        # (matches frontend.frontend_zero_shot logic)
        for i in range(len(prompt_speech_tokens_list)):
            token_len = min(int(prompt_mels_lens_for_flow[i].item() / 2), len(prompt_speech_tokens_list[i]))
            prompt_speech_tokens_list[i] = prompt_speech_tokens_list[i][:token_len]
            prompt_mels_lens_for_flow[i] = 2 * token_len

        if streaming:
            generated_wavs = []
            for i in range(len(generated_speech_tokens_list)):
                prompt_feat = prompt_mels_for_flow[i:i+1, :prompt_mels_lens_for_flow[i]]
                embedding = spk_emb_for_flow[i:i+1]
                wav = self.forward_stream(
                    generated_speech_tokens_list[i],
                    prompt_speech_tokens_list[i],
                    prompt_feat,
                    embedding,
                )
                generated_wavs.append(wav)
            return generated_wavs

        generated_mels_list = self.forward_flow(
            prompt_speech_tokens_list, generated_speech_tokens_list,
            prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
        generated_wavs = self.forward_hift(generated_mels_list)
        return generated_wavs
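

# --- Hedged usage sketch (illustration only, not wired into the original CLI) -------------
# A minimal smoke test assuming the class above: load a 16 kHz prompt wav, tokenize it with
# speech_tokenizer_v3, then feed those same tokens back through token2wav as a round trip.
# In the real pipeline the generated tokens would come from the LLM; the function name and
# its path arguments are placeholders, not part of the original script.
def _smoke_test(model_dir, prompt_wav, output_wav="output.wav", enable_trt=False, streaming=False):
    model = CosyVoice3_Token2Wav(model_dir, enable_trt=enable_trt, streaming=streaming)

    audio, sr = torchaudio.load(prompt_wav)
    audio = audio.mean(dim=0)  # downmix to mono, shape [T]
    if sr != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(audio)

    # Round trip: reuse the prompt's own speech tokens as the "generated" tokens.
    with torch.inference_mode():
        generated_tokens = model.prompt_audio_tokenization([audio])[0]

    start = time.time()
    wavs = model([generated_tokens], [audio], [16000], streaming=streaming)
    logger.info(f"token2wav took {time.time() - start:.2f}s "
                f"for {wavs[0].shape[1] / 24000:.2f}s of audio")
    torchaudio.save(output_wav, wavs[0].cpu(), 24000)
    return wavs[0]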