token2wav_dit.py

# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
CUDA_VISIBLE_DEVICES=0 \
python3 token2wav.py --enable-trt || exit 1
"""
import torch
# from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
import torchaudio.compliance.kaldi as kaldi
import onnxruntime
import s3tokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import torchaudio
import os
import logging
import argparse
import queue
import time
import numpy as np
from hyperpyyaml import load_hyperpyyaml


def fade_in_out(fade_in_mel: torch.Tensor, fade_out_mel: torch.Tensor, window: torch.Tensor):
    """Perform fade_in_out in tensor style."""
    mel_overlap_len = int(window.shape[0] / 2)
    fade_in_mel = fade_in_mel.clone()
    fade_in_mel[..., :mel_overlap_len] = \
        fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
    return fade_in_mel
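

# Note on fade_in_out: with a symmetric window w of length 2*L (here a Hamming
# window), the first L frames of the incoming chunk are blended with the last L
# frames of the previous chunk as
#     out[i] = fade_in[i] * w[i] + fade_out[-L + i] * w[L + i],
# which replaces a hard seam between chunks with a smooth cross-fade.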


def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, dtype):
    import tensorrt as trt
    logging.info("Converting onnx to trt...")
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
    if dtype == torch.float16:
        config.set_flag(trt.BuilderFlag.FP16)
    elif dtype == torch.bfloat16:
        config.set_flag(trt.BuilderFlag.BF16)
    elif dtype == torch.float32:
        pass  # FP32 is TensorRT's default precision, no builder flag needed
    profile = builder.create_optimization_profile()
    # load onnx model
    with open(onnx_model, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            raise ValueError('failed to parse {}'.format(onnx_model))
    # set input shapes
    for i in range(len(trt_kwargs['input_names'])):
        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
    if dtype == torch.float16:
        tensor_dtype = trt.DataType.HALF
    elif dtype == torch.bfloat16:
        tensor_dtype = trt.DataType.BF16
    elif dtype == torch.float32:
        tensor_dtype = trt.DataType.FLOAT
    else:
        raise ValueError('invalid dtype {}'.format(dtype))
    # set input and output data type
    for i in range(network.num_inputs):
        input_tensor = network.get_input(i)
        input_tensor.dtype = tensor_dtype
    for i in range(network.num_outputs):
        output_tensor = network.get_output(i)
        output_tensor.dtype = tensor_dtype
    config.add_optimization_profile(profile)
    engine_bytes = builder.build_serialized_network(network, config)
    # save trt engine
    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logging.info("Successfully converted onnx to trt...")
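

# convert_onnx_to_trt expects `trt_kwargs` to describe one optimization profile
# entry per ONNX input, e.g. the speaker-embedding profile used below:
#     trt_kwargs = {
#         'input_names': ['input'],
#         'min_shape': [(1, 4, 80)],
#         'opt_shape': [(1, 500, 80)],
#         'max_shape': [(1, 3000, 80)],
#     }
# The min/opt/max shapes define the dynamic-shape ranges the engine is built for.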


class TrtContextWrapper:
    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        self.trt_engine = trt_engine
        self.device = device
        for _ in range(trt_concurrent):
            trt_context = trt_engine.create_execution_context()
            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
            assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reducing trt_concurrent {}'.format(trt_concurrent)
            self.trt_context_pool.put([trt_context, trt_stream])
        assert self.trt_context_pool.empty() is False, 'no available estimator context'

    def acquire_estimator(self):
        return self.trt_context_pool.get(), self.trt_engine

    def release_estimator(self, context, stream):
        self.trt_context_pool.put([context, stream])
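

# Sketch of how the wrapper is used at the call sites in this file (see
# forward_spk_embedding below): acquire a context/stream pair, run the engine
# on that stream, then return the pair to the pool so other requests can reuse
# it. The queue bounds how many execution contexts run concurrently.
#     [context, stream], engine = wrapper.acquire_estimator()
#     with stream:
#         ...  # set input shapes / tensor addresses, then context.execute_async_v3(...)
#     wrapper.release_estimator(context, stream)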


class CosyVoice2_Token2Wav(torch.nn.Module):
    def __init__(self, model_dir: str, enable_trt: bool = False, device_id: int = 0, streaming: bool = False, dtype: torch.dtype = torch.float16):
        super().__init__()
        self.device_id = device_id
        self.device = f"cuda:{device_id}"

        with open(f"{model_dir}/flow.yaml", "r") as f:
            configs = load_hyperpyyaml(f)
        self.flow = configs['flow']
        self.dtype = dtype
        self.flow.to(self.dtype)
        self.flow.load_state_dict(torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
        self.flow.to(self.device).eval()

        self.hift = HiFTGenerator()
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(f"{model_dir}/campplus.onnx", sess_options=option,
                                                      providers=["CPUExecutionProvider"])

        self.audio_tokenizer = s3tokenizer.load_model(f"{model_dir}/speech_tokenizer_v2_25hz.onnx").to(self.device).eval()

        gpu = "l20"
        if enable_trt:
            if streaming:
                self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.chunk.{gpu}.plan',
                              f'{model_dir}/flow.decoder.estimator.chunk.fp32.dynamic_batch.simplify.onnx',
                              1,
                              self.dtype, streaming)
            else:
                self.load_trt(f'{model_dir}/flow.decoder.estimator.{self.dtype}.dynamic_batch.{gpu}.plan',
                              f'{model_dir}/flow.decoder.estimator.fp32.dynamic_batch.onnx',
                              1,
                              self.dtype)
            self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
                              f'{model_dir}/campplus.onnx',
                              1,
                              False)

        self.streaming_flow_cache = {}
        self.speaker_cache = {}
        self.mel_cache_len = 8  # hard-coded, 160ms
        self.source_cache_len = int(self.mel_cache_len * 480)  # 50hz mel -> 24kHz wave
        self.speech_window = torch.from_numpy(np.hamming(2 * self.source_cache_len)).cuda()
        # hifigan cache for streaming tts
        self.hift_cache_dict = {}
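
    # Streaming cache sizing: mel frames are produced at 50 Hz and the vocoder
    # emits 24 kHz audio, so one mel frame corresponds to 24000 / 50 = 480
    # samples. Keeping mel_cache_len = 8 frames (160 ms) therefore gives
    # source_cache_len = 8 * 480 = 3840 samples, and speech_window is a Hamming
    # window of twice that length used to cross-fade consecutive chunks.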

    def forward_spk_embedding(self, spk_feat):
        if isinstance(self.spk_model, onnxruntime.InferenceSession):
            return self.spk_model.run(
                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
            )[0].flatten().tolist()
        else:
            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
            # NOTE need to synchronize when switching stream
            with torch.cuda.device(self.device_id):
                torch.cuda.current_stream().synchronize()
                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
                batch_size = spk_feat.size(0)
                with stream:
                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
                    data_ptrs = [spk_feat.contiguous().data_ptr(),
                                 output_tensor.contiguous().data_ptr()]
                    for i, j in enumerate(data_ptrs):
                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
                    # run trt engine
                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
                torch.cuda.current_stream().synchronize()
            self.spk_model.release_estimator(spk_model, stream)
            return output_tensor.cpu().numpy().flatten().tolist()

    def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
        if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
            trt_kwargs = self.get_spk_trt_kwargs()
            # convert_onnx_to_trt expects a torch dtype, so map the fp16 flag onto one
            convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, torch.float16 if fp16 else torch.float32)
        import tensorrt as trt
        with open(spk_model, 'rb') as f:
            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_spk_trt_kwargs(self):
        min_shape = [(1, 4, 80)]
        opt_shape = [(1, 500, 80)]
        max_shape = [(1, 3000, 80)]
        input_names = ["input"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent=1, dtype=torch.float16, streaming=False):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            opt_batch_size = 2
            max_batch_size = 16
            if streaming:
                opt_batch_size, max_batch_size = 1, 1  # only support batch size 1 for streaming tts
            trt_kwargs = self.get_trt_kwargs_dynamic_batch(opt_batch_size=opt_batch_size, max_batch_size=max_batch_size, streaming=streaming)
            convert_onnx_to_trt(flow_decoder_estimator_model, trt_kwargs, flow_decoder_onnx_model, dtype)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_trt_kwargs_dynamic_batch(self, opt_batch_size=2, max_batch_size=64, streaming=False):
        if streaming:
            min_shape = [(2, 80, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80), (16, 2, 1024, 2), (16, 2, 8, 0, 128)]
            opt_shape = [(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80), (16, opt_batch_size * 2, 1024, 2), (16, opt_batch_size * 2, 8, 100, 128)]
            max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80), (16, max_batch_size * 2, 1024, 2), (16, max_batch_size * 2, 8, 1000, 128)]
            input_names = ["x", "mu", "cond", "t", "spks", "cnn_cache", "att_cache"]
        else:
            min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (2,), (2, 80)]
            opt_shape = [(opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 1, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2, 80, 500), (opt_batch_size * 2,), (opt_batch_size * 2, 80)]
            max_shape = [(max_batch_size * 2, 80, 3000), (max_batch_size * 2, 1, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2, 80, 3000), (max_batch_size * 2,), (max_batch_size * 2, 80)]
            input_names = ["x", "mask", "mu", "cond", "t", "spks"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
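
    # The leading batch dimension in these profiles is doubled (batch_size * 2),
    # which appears to account for the conditional and unconditional branches of
    # the flow-matching estimator being stacked for classifier-free guidance.
    # Feature dims are fixed at 80 mel bins and the time axis may range from 4 to
    # 3000 frames; in streaming mode the extra "cnn_cache" / "att_cache" inputs
    # carry the causal decoder state between chunks.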

    def prompt_audio_tokenization(self, prompt_audios_list: list[torch.Tensor]) -> list[list[int]]:
        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
            prompt_speech_mels_list.append(log_mel)
        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
        )
        for i in range(len(prompt_speech_tokens)):
            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
            prompt_speech_tokens_list.append(speech_tokens_i)
        return prompt_speech_tokens_list
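
    # speech_tokenizer_v2_25hz yields about 25 speech tokens per second, while
    # the flow model works on 50 Hz mel frames, i.e. roughly two mel frames per
    # token. forward_streaming relies on this 1:2 ratio when it trims the prompt
    # to min(num_mel_frames // 2, num_prompt_tokens) tokens.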

    def get_spk_emb(self, prompt_audios_list: list[torch.Tensor]) -> torch.Tensor:
        spk_emb_for_flow = []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
            spk_emb = self.forward_spk_embedding(spk_feat)
            spk_emb_for_flow.append(spk_emb)
        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
        if self.dtype != torch.float32:
            spk_emb_for_flow = spk_emb_for_flow.to(self.dtype)
        return spk_emb_for_flow

    def get_prompt_mels(self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]):
        prompt_mels_for_flow = []
        prompt_mels_lens_for_flow = []
        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
            assert len(audio.shape) == 1
            audio = audio.unsqueeze(0)
            if sample_rate != 24000:
                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
            mel_len = mel.shape[0]
            prompt_mels_for_flow.append(mel)
            prompt_mels_lens_for_flow.append(mel_len)
        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
        return prompt_mels_for_flow, prompt_mels_lens_for_flow

    def forward_flow(self, prompt_speech_tokens_list: list[list[int]], generated_speech_tokens_list: list[list[int]], prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor, spk_emb_for_flow: torch.Tensor):
        batch_size = prompt_mels_for_flow.shape[0]
        flow_inputs = []
        flow_inputs_lens = []
        for prompt_speech_tokens, generated_speech_tokens in zip(prompt_speech_tokens_list, generated_speech_tokens_list):
            flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
            flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
        flow_inputs_lens = torch.tensor(flow_inputs_lens)

        with torch.amp.autocast(self.device, dtype=torch.float16):
            generated_mels, generated_mels_lens = self.flow.inference(
                flow_inputs.to(self.device), flow_inputs_lens.to(self.device),
                prompt_mels_for_flow.to(self.device), prompt_mels_lens_for_flow.to(self.device), spk_emb_for_flow.to(self.device), 10
            )
        return generated_mels, generated_mels_lens

    def forward_hift(self, generated_mels: torch.Tensor, generated_mels_lens: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor):
        batch_size = generated_mels.shape[0]
        generated_wavs = []
        for i in range(batch_size):
            mel = generated_mels[i, :, prompt_mels_lens_for_flow[i].item():generated_mels_lens[i].item()].unsqueeze(0)
            wav, _ = self.hift(speech_feat=mel)
            generated_wavs.append(wav)
        return generated_wavs
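
    # forward_hift vocodes each item separately and slices off the first
    # prompt_mels_lens_for_flow[i] mel frames, so the returned waveforms contain
    # only the newly generated speech, not a re-synthesis of the prompt audio.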

    @torch.inference_mode()
    def forward(
        self, generated_speech_tokens_list: list[list[int]], prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        # all prompt audios must be sampled at 16000 Hz
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
        prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio(prompt_audios_list, prompt_audios_sample_rate)
        generated_mels, generated_mels_lens = self.forward_flow(prompt_speech_tokens_list, generated_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
        generated_wavs = self.forward_hift(generated_mels, generated_mels_lens, prompt_mels_lens_for_flow)
        return generated_wavs
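
    # Offline usage sketch (names are illustrative; `tokens` is a list of
    # CosyVoice2 speech-token ids, `prompt` a 16 kHz mono waveform tensor):
    #     model = CosyVoice2_Token2Wav("./Step-Audio-2-mini/token2wav", enable_trt=True)
    #     wavs = model([tokens], [prompt], [16000])
    #     torchaudio.save("out.wav", wavs[0].cpu(), 24000)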

    def prepare_prompt_audio(
        self, prompt_audios_list: list[torch.Tensor], prompt_audios_sample_rate: list[int]
    ):
        # all prompt audios must be sampled at 16000 Hz
        assert all(sample_rate == 16000 for sample_rate in prompt_audios_sample_rate)
        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(prompt_audios_list, prompt_audios_sample_rate)
        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
        return prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow

    def get_prompt_audio_cache_for_streaming_tts(
        self, prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow
    ):
        assert len(prompt_speech_tokens_list) == 1, "only support batch size 1 for streaming tts"
        for i, prompt_speech_tokens in enumerate(prompt_speech_tokens_list):
            prompt_speech_tokens_list[i] = torch.tensor(prompt_speech_tokens + prompt_speech_tokens_list[i][:3])
        prompt_speech_tokens_tensor = torch.nn.utils.rnn.pad_sequence(prompt_speech_tokens_list, batch_first=True, padding_value=0)
        cache = self.flow.setup_cache(
            prompt_speech_tokens_tensor.to(self.device),
            prompt_mels_for_flow.to(self.device),
            spk_emb_for_flow.to(self.device),
            n_timesteps=10
        )
        # Hack: clone so later chunks do not modify cache['estimator_att_cache']
        # and cache['estimator_cnn_cache'] in place
        cache['estimator_att_cache'] = cache['estimator_att_cache'].clone()
        cache['estimator_cnn_cache'] = cache['estimator_cnn_cache'].clone()
        return cache

    @torch.inference_mode()
    def forward_streaming(
        self, generated_speech_tokens: list[int], last_chunk: bool, request_id: str, speaker_id: str, prompt_audio: torch.Tensor = None, prompt_audio_sample_rate: int = 16000
    ):
        if speaker_id not in self.speaker_cache:
            assert prompt_audio is not None, "prompt_audio is required for new speaker"
            assert prompt_audio_sample_rate == 16000
            prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow = self.prepare_prompt_audio([prompt_audio], [prompt_audio_sample_rate])

            token_len = min(int(prompt_mels_for_flow.shape[1] / 2), len(prompt_speech_tokens_list[0]))
            prompt_mels_for_flow = prompt_mels_for_flow[:, :2 * token_len].contiguous()
            prompt_speech_tokens_list[0] = prompt_speech_tokens_list[0][:token_len]

            cache_dict = self.get_prompt_audio_cache_for_streaming_tts(prompt_speech_tokens_list, prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
            prompt_audio_dict = {'spk_emb_for_flow': spk_emb_for_flow, 'prompt_mels_for_flow': prompt_mels_for_flow}
            self.speaker_cache[speaker_id] = {'prompt_audio_dict': prompt_audio_dict, 'cache_dict': cache_dict}

        if request_id not in self.streaming_flow_cache:
            self.streaming_flow_cache[request_id] = {k: v.clone() for k, v in self.speaker_cache[speaker_id]['cache_dict'].items()}
            self.hift_cache_dict[request_id] = dict(
                mel=torch.zeros(1, 80, 0, device='cuda'),
                source=torch.zeros(1, 1, 0, device='cuda'),
                speech=torch.zeros(1, 0, device='cuda'),
            )

        current_request_cache = self.streaming_flow_cache[request_id]
        current_prompt_audio_dict = self.speaker_cache[speaker_id]['prompt_audio_dict']

        generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
        chunk_mel, new_streaming_flow_cache = self.flow.inference_chunk(
            token=generated_speech_tokens,
            spk=current_prompt_audio_dict['spk_emb_for_flow'].to(self.device),
            cache=current_request_cache,
            last_chunk=last_chunk,
            n_timesteps=10,
        )
        self.streaming_flow_cache[request_id] = new_streaming_flow_cache

        if self.streaming_flow_cache[request_id]['estimator_att_cache'].shape[4] > (current_prompt_audio_dict['prompt_mels_for_flow'].shape[1] + 100):
            self.streaming_flow_cache[request_id]['estimator_att_cache'] = torch.cat([
                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, :current_prompt_audio_dict['prompt_mels_for_flow'].shape[1]],
                self.streaming_flow_cache[request_id]['estimator_att_cache'][:, :, :, :, -100:],
            ], dim=4)

        hift_cache_mel = self.hift_cache_dict[request_id]['mel']
        hift_cache_source = self.hift_cache_dict[request_id]['source']
        hift_cache_speech = self.hift_cache_dict[request_id]['speech']
        mel = torch.concat([hift_cache_mel, chunk_mel], dim=2)

        speech, source = self.hift(mel, hift_cache_source)
        # cross-fade with the previous chunk to smooth the overlap
        if hift_cache_speech.shape[-1] > 0:
            speech = fade_in_out(speech, hift_cache_speech, self.speech_window)
        # update vocoder cache
        self.hift_cache_dict[request_id] = dict(
            mel=mel[..., -self.mel_cache_len:].clone().detach(),
            source=source[:, :, -self.source_cache_len:].clone().detach(),
            speech=speech[:, -self.source_cache_len:].clone().detach(),
        )
        if not last_chunk:
            speech = speech[:, :-self.source_cache_len]
        if last_chunk:
            assert request_id in self.streaming_flow_cache
            self.streaming_flow_cache.pop(request_id)
            self.hift_cache_dict.pop(request_id)
        return speech
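

# Streaming usage sketch (illustrative only): construct the model with
# streaming=True when enable_trt is set so the chunked estimator engine is
# loaded, then feed token chunks for one request and mark the final chunk so the
# per-request caches are released. prompt_audio is only required on the first
# call for a new speaker_id; later calls can rely on the speaker cache.
#     chunks = [tokens[i:i + 25] for i in range(0, len(tokens), 25)]
#     for i, chunk in enumerate(chunks):
#         wav_chunk = model.forward_streaming(
#             chunk, last_chunk=(i == len(chunks) - 1),
#             request_id="req-0", speaker_id="spk-0",
#             prompt_audio=prompt, prompt_audio_sample_rate=16000)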


def collate_fn(batch):
    ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
    for item in batch:
        generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
        audio = torch.from_numpy(item['prompt_audio']['array']).float()
        prompt_audios_list.append(audio)
        prompt_audios_sample_rate.append(item['prompt_audio']['sampling_rate'])
        ids.append(item['id'])
    return ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-trt", action="store_true")
    parser.add_argument("--model-dir", type=str, default="./Step-Audio-2-mini/token2wav")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--output-dir", type=str, default="generated_wavs")
    parser.add_argument("--huggingface-dataset-split", type=str, default="wenetspeech4tts")
    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup epochs, performance statistics will only be collected from the last epoch")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    model = CosyVoice2_Token2Wav(model_dir=args.model_dir, enable_trt=args.enable_trt)
    # mkdir output_dir if not exists
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    dataset_name = "yuekai/seed_tts_cosy2"
    dataset = load_dataset(dataset_name, split=args.huggingface_dataset_split, trust_remote_code=True)
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)

    for epoch in range(args.warmup):
        start_time = time.time()
        for batch in data_loader:
            ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch
            generated_wavs = model(generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate)
            for utt_id, wav in zip(ids, generated_wavs):
                torchaudio.save(f"{args.output_dir}/{utt_id}.wav", wav.cpu(), 24000)
        end_time = time.time()
        epoch_time = end_time - start_time
        print(f"Measurement epoch time taken: {epoch_time:.4f} seconds")