token2wav.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
  2. # SPDX-License-Identifier: Apache-2.0
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """ Example Usage
  16. CUDA_VISIBLE_DEVICES=0 \
  17. python3 token2wav.py --enable-trt || exit 1
  18. """
  19. import torch
  20. from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
  21. from flashcosyvoice.modules.hifigan import HiFTGenerator
  22. from flashcosyvoice.utils.audio import mel_spectrogram
  23. import torchaudio.compliance.kaldi as kaldi
  24. import onnxruntime
  25. import s3tokenizer
  26. from torch.utils.data import DataLoader
  27. from datasets import load_dataset
  28. import torchaudio
  29. import os
  30. import logging
  31. import argparse
  32. import queue
  33. import time
  def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
      """Build a TensorRT engine from an ONNX file and serialize it to ``trt_model``.

      Args:
          trt_model: Output path for the serialized TensorRT engine (plan file).
          trt_kwargs: Dict with keys ``input_names``, ``min_shape``, ``opt_shape``
              and ``max_shape``. The i-th entry of each shape list is applied to
              the i-th input name, all within a single optimization profile.
          onnx_model: Path of the ONNX model to parse.
          fp16: If True, enable the FP16 builder flag and mark every network
              input and output as HALF; otherwise they are marked FLOAT.

      Raises:
          ValueError: if the ONNX file cannot be parsed (parser errors are printed).
          RuntimeError: if TensorRT fails to build the engine. In that case the
              output file is not created.
      """
      import tensorrt as trt
      logging.info("Converting onnx to trt...")
      network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
      logger = trt.Logger(trt.Logger.INFO)
      builder = trt.Builder(logger)
      network = builder.create_network(network_flags)
      parser = trt.OnnxParser(network, logger)
      config = builder.create_builder_config()
      # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
      if fp16:
          config.set_flag(trt.BuilderFlag.FP16)
      profile = builder.create_optimization_profile()
      # load onnx model
      with open(onnx_model, "rb") as f:
          if not parser.parse(f.read()):
              for error in range(parser.num_errors):
                  print(parser.get_error(error))
              raise ValueError('failed to parse {}'.format(onnx_model))
      # set input shapes: one profile covering every dynamic input
      for i, input_name in enumerate(trt_kwargs['input_names']):
          profile.set_shape(input_name, trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
      tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
      # set input and output data type
      for i in range(network.num_inputs):
          network.get_input(i).dtype = tensor_dtype
      for i in range(network.num_outputs):
          network.get_output(i).dtype = tensor_dtype
      config.add_optimization_profile(profile)
      engine_bytes = builder.build_serialized_network(network, config)
      # build_serialized_network returns None when the build fails. Check before
      # opening the output file so a failed build does not leave an empty plan file.
      if engine_bytes is None:
          raise RuntimeError('failed to build trt engine from {}'.format(onnx_model))
      # save trt engine
      with open(trt_model, "wb") as f:
          f.write(engine_bytes)
      logging.info("Succesfully convert onnx to trt...")
  class TrtContextWrapper:
      """Fixed-size pool of TensorRT execution contexts sharing one engine.

      Each pool entry is a ``[context, stream]`` pair, where ``stream`` is the
      ``torch.cuda.stream(...)`` context manager wrapping a dedicated CUDA stream
      on ``device``. ``acquire_estimator`` blocks until an entry is available.
      """

      def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
          """Create ``trt_concurrent`` execution contexts, each with its own stream."""
          pool = queue.Queue(maxsize=trt_concurrent)
          self.trt_context_pool = pool
          self.trt_engine = trt_engine
          self.device = device
          remaining = trt_concurrent
          while remaining > 0:
              remaining -= 1
              exec_ctx = trt_engine.create_execution_context()
              cuda_stream = torch.cuda.Stream(torch.device(device))
              stream_ctx = torch.cuda.stream(cuda_stream)
              assert exec_ctx is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
              pool.put([exec_ctx, stream_ctx])
          assert not pool.empty(), 'no avaialbe estimator context'

      def acquire_estimator(self):
          """Take a ``[context, stream]`` pair from the pool (blocking); also return the engine."""
          entry = self.trt_context_pool.get()
          return entry, self.trt_engine

      def release_estimator(self, context, stream):
          """Return a previously acquired ``context``/``stream`` pair to the pool."""
          pair = [context, stream]
          self.trt_context_pool.put(pair)
  class CosyVoice2_Token2Wav(torch.nn.Module):
      """CosyVoice2 token-to-waveform pipeline: speech tokens -> mel (flow) -> wav (HiFT).

      Prompt audio is used in three ways:
      - it is tokenized with the S3 tokenizer, and those tokens are prepended to
        the generated tokens;
      - it is turned into a 24 kHz mel spectrogram used as flow conditioning;
      - it is embedded by the CAM++ speaker model.
      """

      def __init__(self, model_dir: str = "./CosyVoice2-0.5B", enable_trt: bool = False, device_id: int = 0):
          """Load all sub-models from ``model_dir`` onto ``cuda:{device_id}``.

          Args:
              model_dir: Directory containing ``flow.pt``, ``hift.pt``,
                  ``campplus.onnx`` and ``speech_tokenizer_v2.onnx``. When
                  ``enable_trt`` is set, it also holds the TensorRT plan and ONNX
                  files used by ``load_trt`` / ``load_spk_trt``.
              enable_trt: Replace the flow decoder estimator and the speaker model
                  with TensorRT engines. Each engine is built from ONNX if its
                  plan file is missing or empty.
              device_id: CUDA device index.
          """
          super().__init__()
          self.device_id = device_id
          self.device = f"cuda:{device_id}"

          # Flow model: converted to fp16 before the checkpoint is loaded into it.
          self.flow = CausalMaskedDiffWithXvec()
          self.flow.half()
          self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
          self.flow.to(self.device).eval()

          # HiFT vocoder: the checkpoint keys carry a 'generator.' prefix that is stripped here.
          self.hift = HiFTGenerator()
          hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
          self.hift.load_state_dict(hift_state_dict, strict=True)
          self.hift.to(self.device).eval()

          # Speaker model: single-threaded onnxruntime on CPU, unless replaced by TensorRT below.
          option = onnxruntime.SessionOptions()
          option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
          option.intra_op_num_threads = 1
          self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])

          self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2.onnx").to(self.device).eval()

          # Tag embedded in the cached TensorRT plan filenames.
          # NOTE(review): hard-coded; presumably names the GPU the plans were built for.
          gpu = "l20"
          if enable_trt:
              self.load_trt(f'{model_dir}/flow.decoder.estimator.fp16.dynamic_batch.{gpu}.plan',
                            f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
                            1,
                            True)
              self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
                                f'{model_dir}/campplus.onnx',
                                1,
                                False)

      def forward_spk_embedding(self, spk_feat):
          """Compute a speaker embedding for one utterance's fbank features.

          Args:
              spk_feat: 2-D feature tensor with 80 features per frame (as produced
                  by ``get_spk_emb``). A batch dimension of 1 is added here.

          Returns:
              The embedding as a flat Python list of floats. The TensorRT path
              allocates a 192-wide output.
          """
          if isinstance(self.spk_model, onnxruntime.InferenceSession):
              return self.spk_model.run(
                  None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
              )[0].flatten().tolist()

          [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
          try:
              # NOTE need to synchronize when switching stream
              with torch.cuda.device(self.device_id):
                  torch.cuda.current_stream().synchronize()
                  spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
                  batch_size = spk_feat.size(0)
                  with stream:
                      spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
                      output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
                      # Binding order: engine tensor 0 is the input, tensor 1 the output.
                      data_ptrs = [spk_feat.contiguous().data_ptr(),
                                   output_tensor.contiguous().data_ptr()]
                      for i, j in enumerate(data_ptrs):
                          spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
                      # run trt engine
                      assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
                      torch.cuda.current_stream().synchronize()
          finally:
              # Always hand the context back. Otherwise a failure above would leave it
              # checked out, and later calls would block forever in acquire_estimator.
              self.spk_model.release_estimator(spk_model, stream)
          return output_tensor.cpu().numpy().flatten().tolist()

      def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
          """Replace ``self.spk_model`` with a pool of TensorRT contexts.

          The plan at ``spk_model`` is first built from ``spk_onnx_model`` if the
          plan file is missing or empty.
          """
          if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
              trt_kwargs = self.get_spk_trt_kwargs()
              convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
          import tensorrt as trt
          with open(spk_model, 'rb') as f:
              spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
          assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
          self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)

      def get_spk_trt_kwargs(self):
          """Return the optimization-profile shapes for the speaker model's ``input`` (batch 1, 4..3000 frames, 80 features)."""
          min_shape = [(1, 4, 80)]
          opt_shape = [(1, 500, 80)]
          max_shape = [(1, 3000, 80)]
          input_names = ["input"]
          return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

      def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, fp16=True):
          """Replace the flow decoder estimator with a pool of TensorRT contexts.

          The plan is built from ``flow_decoder_onnx_model`` (opt batch 2, max
          batch 16) if the plan file is missing or empty.
          """
          assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
          if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
              trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_bs=2, max_batch_size=16)
              convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, fp16)
          # Drop the PyTorch estimator; the TensorRT wrapper assigned below takes its place.
          del self.flow.decoder.estimator
          import tensorrt as trt
          with open(flow_decoder_estimator_model, 'rb') as f:
              estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
          assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
          self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

      def get_trt_kwargs_dynamic_batch(self, opt_bs=2, max_batch_size=64):
          """Return optimization-profile shapes for the estimator inputs x, mask, mu, cond, t, spks.

          The leading dimension is always twice the batch size.
          NOTE(review): presumably for paired conditional/unconditional passes; confirm.
          The time axis ranges over 4 (min), 500 (opt) and 3000 (max).
          """
          min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
          opt_shape = [(opt_bs * 2, 80, 500), (opt_bs * 2, 1, 500), (opt_bs * 2, 80, 500), (opt_bs * 2, 80, 500), (opt_bs * 2,), (opt_bs * 2, 80)]
          max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,),
                       (max_batch_size * 2, 80)]
          input_names = ["x", "mask", "mu", "cond", "t", "spks"]
          return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

      def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
          """Tokenize each 1-D prompt waveform; return one unpadded token list per prompt.

          ``forward`` only passes 16 kHz audio here.
          """
          prompt_speech_tokens_list, prompt_speech_mels_list = [], []
          for audio in prompt_audios_list:
              assert len(audio.shape) == 1
              log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
              prompt_speech_mels_list.append(log_mel)
          prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
          prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
              prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
          )
          for i in range(len(prompt_speech_tokens)):
              # Trim batch padding using the per-item token length.
              speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
              prompt_speech_tokens_list.append(speech_tokens_i)
          return prompt_speech_tokens_list

      def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
          """Return a tensor holding one speaker embedding per 1-D prompt waveform.

          Features are 80-bin kaldi fbank at 16 kHz, mean-normalized over time.
          """
          spk_emb_for_flow = []
          for audio in prompt_audios_list:
              assert len(audio.shape) == 1
              spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
              spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
              spk_emb = self.forward_spk_embedding(spk_feat)
              spk_emb_for_flow.append(spk_emb)
          spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
          return spk_emb_for_flow

      def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
          """Compute 24 kHz mel spectrograms for the prompts, resampling when needed.

          Returns:
              ``(mels, lens)``: mels zero-padded to ``[B, T', 80]``, and a tensor
              of the unpadded frame counts.
          """
          prompt_mels_for_flow = []
          prompt_mels_lens_for_flow = []
          for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
              assert len(audio.shape) == 1
              audio = audio.unsqueeze(0)
              if sample_rate != 24000:
                  audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
              mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
              mel_len = mel.shape[0]
              prompt_mels_for_flow.append(mel)
              prompt_mels_lens_for_flow.append(mel_len)
          prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
          prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
          return prompt_mels_for_flow, prompt_mels_lens_for_flow

      def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor,
                       prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
          """Run the flow model on prompt+generated tokens under fp16 autocast.

          For each item, the prompt tokens are concatenated in front of the
          generated tokens. The sequences are zero-padded into one batch.

          Returns:
              ``(generated_mels, generated_mels_lens)`` as returned by ``self.flow``.
          """
          flow_inputs = []
          flow_inputs_lens = []
          for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
              flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
              flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
          flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
          flow_inputs_lens = torch.tensor(flow_inputs_lens)
          # autocast expects a device *type* ("cuda"), not a device string like "cuda:0".
          with torch.amp.autocast("cuda", dtype=torch.float16):
              generated_mels, generated_mels_lens = self.flow(
                  flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
                  prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device),
                  streaming=False, finalize=True
              )
          return generated_mels, generated_mels_lens

      def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
          """Vocode each item's mel frames between the prompt length and its total length.

          The prompt frames are dropped. Returns a list of waveforms, one per item.
          """
          batch_size = generated_mels.shape[0]
          generated_wavs = []
          for i in range(batch_size):
              mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
              wav, _ = self.hift(speech_feat=mel)
              generated_wavs.append(wav)
          return generated_wavs

      @torch.inference_mode()
      def forward(
          self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
      ):
          """Synthesize one waveform per item from generated tokens and a 16 kHz prompt.

          Raises:
              AssertionError: if any prompt sample rate is not 16000.
          """
          # assert all item in prompt_audios_sample_rate is 16000
          assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
          prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
          prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
          spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
          generated_mels, generated_mels_lens = self.forward_flow(
              prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
          generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
          return generated_wavs
  def collate_fn(batch):
      """Collate dataset rows into parallel per-field lists for ``CosyVoice2_Token2Wav``.

      Args:
          batch: Sequence of dataset rows. Each row provides ``id``,
              ``target_audio_cosy2_tokens`` and ``prompt_audio``, a dict with a
              numpy ``array`` and its ``sampling_rate``.

      Returns:
          Tuple ``(ids, generated_speech_tokens_list, prompt_audios_list,
          prompt_audios_sample_rate)``. Prompt audio arrays are converted to
          float32 torch tensors; the other fields are passed through unchanged.
      """
      ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
      for item in batch:
          prompt_audio = item['prompt_audio']
          generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
          prompt_audios_list.append(torch.from_numpy(prompt_audio['array']).float())
          prompt_audios_sample_rate.append(prompt_audio['sampling_rate'])
          ids.append(item['id'])
      return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate
  def get_args():
      """Parse the benchmark's command-line options.

      Returns:
          argparse.Namespace with ``enable_trt``, ``model_dir``, ``batch_size``,
          ``output_dir``, ``huggingface_dataset_split`` and ``warmup``.
      """
      cli = argparse.ArgumentParser()
      option_specs = (
          ("--enable-trt", {"action": "store_true"}),
          ("--model-dir", {"type": str, "default": "./CosyVoice2-0.5B"}),
          ("--batch-size", {"type": int, "default": 4}),
          ("--output-dir", {"type": str, "default": "generated_wavs"}),
          ("--huggingface-dataset-split", {"type": str, "default": "wenetspeech4tts"}),
          ("--warmup", {"type": int, "default": 3,
                        "help": "Number of warmup epochs, performance statistics will only be collected from the last epoch"}),
      )
      for flag, spec in option_specs:
          cli.add_argument(flag, **spec)
      return cli.parse_args()
  if __name__ == "__main__":
      args = get_args()
      model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
      # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
      os.makedirs(args.output_dir, exist_ok=True)

      dataset_name = "yuekai/seed_tts_cosy2"
      # NOTE(review): trust_remote_code=True runs code shipped with the dataset repo;
      # only use it with a trusted dataset.
      dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
      data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

      # Run the whole split ``args.warmup`` times and report wall time per epoch.
      # Per the --warmup help text, the last epoch is the one of interest.
      for _ in range(args.warmup):
          start_time = time.perf_counter()  # monotonic clock intended for interval timing
          for batch in data_loader:
              ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
              generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
              for utt_id, wav in zip(ids, generated_wavs):
                  torchaudio.save(os.path.join(args.output_dir, f"{utt_id}.wav"), wav.cpu(), 24000)
          epoch_time = time.perf_counter() - start_time
          print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")