  1. """ Example Usage
  2. CUDA_VISIBLE_DEVICES=0 \
  3. python3 token2wav_cosyvoice3.py --enable-trt || exit 1
  4. """
  5. import torch
  6. import torchaudio
  7. import torchaudio.compliance.kaldi as kaldi
  8. import onnxruntime
  9. import s3tokenizer
  10. import os
  11. import logging
  12. import argparse
  13. import queue
  14. import time
  15. import numpy as np
  16. from functools import partial
  17. from hyperpyyaml import load_hyperpyyaml
  18. from matcha.utils.audio import mel_spectrogram as matcha_mel_spectrogram
  19. from torch.utils.data import DataLoader
  20. from datasets import load_dataset
  21. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  22. logger = logging.getLogger(__name__)
  23. # CosyVoice3 mel params from cosyvoice3.yaml (fmax=None, NOT 8000)
  24. mel_spectrogram = partial(matcha_mel_spectrogram,
  25. n_fft=1920, num_mels=80, sampling_rate=24000,
  26. hop_size=480, win_size=1920, fmin=0, fmax=None, center=False)
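
# Note: hop_size=480 at 24 kHz is 20 ms per frame, i.e. 50 mel frames per second.
# The speech tokenizer appears to run at 25 tokens per second (an inference from
# the 2:1 mel-to-token trim that forward() enforces below, not stated in this file).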


def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16, autocast_mode=False):
    import tensorrt as trt
    logging.info("Converting onnx to trt...")
    if autocast_mode:
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    else:
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4 GiB
    if not autocast_mode:
        if fp16:
            config.set_flag(trt.BuilderFlag.FP16)
    profile = builder.create_optimization_profile()
    # load onnx model
    with open(onnx_model, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            raise ValueError('failed to parse {}'.format(onnx_model))
    # set input shapes
    for i in range(len(trt_kwargs['input_names'])):
        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i],
                          trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
    # set input and output data type
    for i in range(network.num_inputs):
        input_tensor = network.get_input(i)
        input_tensor.dtype = tensor_dtype
    for i in range(network.num_outputs):
        output_tensor = network.get_output(i)
        output_tensor.dtype = tensor_dtype
    config.add_optimization_profile(profile)
    engine_bytes = builder.build_serialized_network(network, config)
    # save trt engine
    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logging.info("Successfully converted onnx to trt.")
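
# A built .plan engine is specific to the GPU and TensorRT version it was built
# on, which is why the callers below key the filename on device_id and rebuild
# whenever the file is missing or empty. Minimal standalone call, using the
# campplus shapes from get_spk_trt_kwargs below (paths are illustrative):
#
#   kwargs = {'input_names': ['input'], 'min_shape': [(1, 4, 80)],
#             'opt_shape': [(1, 500, 80)], 'max_shape': [(1, 3000, 80)]}
#   convert_onnx_to_trt('campplus.0.fp32.plan', kwargs, 'campplus.onnx', fp16=False)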


class TrtContextWrapper:

    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        self.trt_engine = trt_engine
        self.device = device
        for _ in range(trt_concurrent):
            trt_context = trt_engine.create_execution_context()
            trt_stream = torch.cuda.stream(torch.cuda.Stream(torch.device(device)))
            assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory; try reducing trt_concurrent {}'.format(trt_concurrent)
            self.trt_context_pool.put([trt_context, trt_stream])
        assert self.trt_context_pool.empty() is False, 'no available estimator context'

    def acquire_estimator(self):
        return self.trt_context_pool.get(), self.trt_engine

    def release_estimator(self, context, stream):
        self.trt_context_pool.put([context, stream])
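
# TrtContextWrapper is a small blocking pool: acquire_estimator() checks a
# [context, stream] pair out of the queue (blocking when all trt_concurrent
# pairs are in use) and release_estimator() returns it. Usage sketch:
#
#   [ctx, stream], engine = wrapper.acquire_estimator()
#   try:
#       ...  # set shapes / tensor addresses, execute_async_v3
#   finally:
#       wrapper.release_estimator(ctx, stream)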


class CosyVoice3_Token2Wav(torch.nn.Module):

    def __init__(self, model_dir, enable_trt=False, device_id=0, autocast_mode=True, streaming=False):
        super().__init__()
        self.device_id = device_id
        self.device = f"cuda:{device_id}"
        self.autocast_mode = autocast_mode
        self.streaming = streaming
        # Load flow and hift from cosyvoice3.yaml
        with open(f"{model_dir}/cosyvoice3.yaml", "r") as f:
            configs = load_hyperpyyaml(f, overrides={
                'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')
            })
        self.flow = configs['flow']
        self.flow.load_state_dict(
            torch.load(f"{model_dir}/flow.pt", map_location="cpu", weights_only=True),
            strict=True
        )
        self.flow.to(self.device).eval()
        self.hift = configs['hift']
        hift_state_dict = {
            k.replace('generator.', ''): v
            for k, v in torch.load(f"{model_dir}/hift.pt", map_location="cpu", weights_only=True).items()
        }
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()
        # Speaker embedding model (campplus)
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(
            f"{model_dir}/campplus.onnx", sess_options=option,
            providers=["CPUExecutionProvider"]
        )
        # Audio tokenizer v3
        self.audio_tokenizer = s3tokenizer.load_model(
            f"{model_dir}/speech_tokenizer_v3.onnx"
        ).to(self.device).eval()
        self.fp16 = enable_trt
        if enable_trt:
            self.flow.half()
            self.load_trt(model_dir)
            self.load_spk_trt(model_dir)
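
    # Pipeline overview (per sample): prompt wav -> s3tokenizer tokens + campplus
    # speaker embedding + 24 kHz prompt mel; then flow matching (DiT estimator)
    # maps [prompt tokens, generated tokens] to mel, and the HiFT vocoder maps
    # mel to a 24 kHz waveform.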

    def load_trt(self, model_dir, trt_concurrent=1):
        streaming_prefix = 'streaming.' if self.streaming else ''
        if self.autocast_mode:
            onnx_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}autocast_fp16.onnx'
            trt_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}autocast_fp16.{self.device_id}.plan'
        else:
            onnx_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}fp32.onnx'
            trt_path = f'{model_dir}/flow.decoder.estimator.{streaming_prefix}fp32.{self.device_id}.plan'
        if not os.path.exists(trt_path) or os.path.getsize(trt_path) == 0:
            trt_kwargs = self.get_trt_kwargs()
            convert_onnx_to_trt(trt_path, trt_kwargs, onnx_path,
                                fp16=True, autocast_mode=self.autocast_mode)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(trt_path, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(trt_path)
        self.flow.decoder.estimator = TrtContextWrapper(
            estimator_engine, trt_concurrent=trt_concurrent, device=self.device
        )

    def get_trt_kwargs(self):
        # CosyVoice3 DiT estimator has 6 inputs: x, mask, mu, t, spks, cond
        # Only inputs with dynamic dims need optimization profiles.
        # t=[2(fixed)] and spks=[2(fixed),80(fixed)] are fully fixed, TRT infers from ONNX.
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape,
                'max_shape': max_shape, 'input_names': input_names}
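
    # The leading dim of 2 in every profile is the estimator's fixed batch,
    # presumably the conditional and unconditional passes batched together for
    # classifier-free guidance (an assumption; that batching happens inside
    # flow.decoder, not in this file). Dim 80 is mel bins; the last dim is mel
    # frames, dynamic from 4 to 3000 (up to ~60 s at 50 frames/s).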

    def load_spk_trt(self, model_dir, trt_concurrent=1, fp16=False):
        spk_trt_path = f'{model_dir}/campplus.{self.device_id}.fp32.plan'
        spk_onnx_path = f'{model_dir}/campplus.onnx'
        if not os.path.exists(spk_trt_path) or os.path.getsize(spk_trt_path) == 0:
            trt_kwargs = self.get_spk_trt_kwargs()
            convert_onnx_to_trt(spk_trt_path, trt_kwargs, spk_onnx_path, fp16)
        import tensorrt as trt
        with open(spk_trt_path, 'rb') as f:
            spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert spk_engine is not None, 'failed to load trt {}'.format(spk_trt_path)
        self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_spk_trt_kwargs(self):
        min_shape = [(1, 4, 80)]
        opt_shape = [(1, 500, 80)]
        max_shape = [(1, 3000, 80)]
        input_names = ["input"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape,
                'max_shape': max_shape, 'input_names': input_names}

    def forward_spk_embedding(self, spk_feat):
        if isinstance(self.spk_model, onnxruntime.InferenceSession):
            return self.spk_model.run(
                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
            )[0].flatten().tolist()
        else:
            [spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
            with torch.cuda.device(self.device_id):
                torch.cuda.current_stream().synchronize()
                spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
                batch_size = spk_feat.size(0)
                with stream:
                    spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
                    output_tensor = torch.empty((batch_size, 192), device=spk_feat.device)
                    data_ptrs = [spk_feat.contiguous().data_ptr(),
                                 output_tensor.contiguous().data_ptr()]
                    for i, j in enumerate(data_ptrs):
                        spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
                    assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
                    torch.cuda.current_stream().synchronize()
            self.spk_model.release_estimator(spk_model, stream)
            return output_tensor.cpu().numpy().flatten().tolist()
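
    # Both branches return the same thing: a 192-dim campplus speaker embedding
    # as a flat Python list. The TRT branch binds tensor addresses by engine I/O
    # order, relying on the exported campplus graph having exactly one input
    # followed by one output.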

    def prompt_audio_tokenization(self, prompt_audios_list):
        prompt_speech_tokens_list, prompt_speech_mels_list = [], []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            log_mel = s3tokenizer.log_mel_spectrogram(audio)
            prompt_speech_mels_list.append(log_mel)
        prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_speech_mels_list)
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
            prompt_mels_for_llm.to(self.device), prompt_mels_lens_for_llm.to(self.device)
        )
        for i in range(len(prompt_speech_tokens)):
            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
            prompt_speech_tokens_list.append(speech_tokens_i)
        return prompt_speech_tokens_list
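
    # Tokenization is batched: all prompts are padded into one mel batch, then
    # each row is sliced back to its true token length afterwards, so every
    # returned entry is an unpadded per-sample token list.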

    def get_spk_emb(self, prompt_audios_list):
        spk_emb_for_flow = []
        for audio in prompt_audios_list:
            assert len(audio.shape) == 1
            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
            spk_emb = self.forward_spk_embedding(spk_feat)
            spk_emb_for_flow.append(spk_emb)
        spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
        return spk_emb_for_flow

    def get_prompt_mels(self, prompt_audios_list, prompt_audios_sample_rate):
        prompt_mels_for_flow = []
        prompt_mels_lens_for_flow = []
        for audio, sample_rate in zip(prompt_audios_list, prompt_audios_sample_rate):
            assert len(audio.shape) == 1
            audio = audio.unsqueeze(0)
            if sample_rate != 24000:
                audio = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=24000)(audio)
            # CosyVoice3: fmax=None (Nyquist), matching cosyvoice3.yaml
            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, 80]
            prompt_mels_for_flow.append(mel)
            prompt_mels_lens_for_flow.append(mel.shape[0])
        prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(
            prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', 80]
        prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
        return prompt_mels_for_flow, prompt_mels_lens_for_flow
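
    # Note the two front-end sample rates: campplus fbank and s3tokenizer consume
    # the 16 kHz prompt as-is, while the flow-matching prompt mel is computed at
    # 24 kHz (resampled here), since that is the rate the mel/vocoder stack
    # operates at.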

    def forward_flow(self, prompt_speech_tokens_list, generated_speech_tokens_list,
                     prompt_mels_for_flow, prompt_mels_lens_for_flow,
                     spk_emb_for_flow):
        batch_size = len(generated_speech_tokens_list)
        generated_mels_list = []
        # CausalMaskedDiffWithDiT.inference asserts batch_size==1, so iterate per-sample
        for i in range(batch_size):
            token = torch.tensor([generated_speech_tokens_list[i]]).to(self.device)
            token_len = torch.tensor([len(generated_speech_tokens_list[i])]).to(self.device)
            prompt_token = torch.tensor([prompt_speech_tokens_list[i]]).to(self.device)
            prompt_token_len = torch.tensor([len(prompt_speech_tokens_list[i])]).to(self.device)
            prompt_feat = prompt_mels_for_flow[i:i + 1, :prompt_mels_lens_for_flow[i]].to(self.device)
            prompt_feat_len = prompt_mels_lens_for_flow[i:i + 1].to(self.device)
            embedding = spk_emb_for_flow[i:i + 1].to(self.device)
            # CausalMaskedDiffWithDiT.inference returns mel already without prompt portion
            with torch.cuda.amp.autocast(self.fp16):
                mel, _ = self.flow.inference(
                    token=token,
                    token_len=token_len,
                    prompt_token=prompt_token,
                    prompt_token_len=prompt_token_len,
                    prompt_feat=prompt_feat,
                    prompt_feat_len=prompt_feat_len,
                    embedding=embedding,
                    streaming=False,
                    finalize=True
                )
            generated_mels_list.append(mel)
        return generated_mels_list

    def forward_hift(self, generated_mels_list):
        generated_wavs = []
        for mel in generated_mels_list:
            # CausalHiFTGenerator.inference with finalize=True
            wav, _ = self.hift.inference(speech_feat=mel, finalize=True)
            generated_wavs.append(wav)
        return generated_wavs

    def forward_stream(self, generated_speech_tokens, prompt_speech_tokens,
                       prompt_feat, embedding,
                       token_hop_len=25, stream_scale_factor=2, token_max_hop_len=100):
        """Streaming token2wav for a single sample: process tokens in chunks."""
        prompt_token = torch.tensor([prompt_speech_tokens]).to(self.device)
        prompt_token_len = torch.tensor([len(prompt_speech_tokens)]).to(self.device)
        prompt_feat = prompt_feat.to(self.device)
        prompt_feat_len = torch.tensor([prompt_feat.shape[1]]).to(self.device)
        embedding = embedding.to(self.device)
        pre_lookahead_len = self.flow.pre_lookahead_len
        token_mel_ratio = self.flow.token_mel_ratio
        # Align first chunk with hop_len boundary
        prompt_token_pad = int(
            np.ceil(prompt_token.shape[1] / token_hop_len) * token_hop_len
            - prompt_token.shape[1]
        )
        total_tokens = len(generated_speech_tokens)
        token_offset = 0
        current_hop = token_hop_len
        hift_cache_mel = None
        speech_offset = 0
        audio_chunks = []
        while token_offset < total_tokens:
            this_hop = current_hop + prompt_token_pad if token_offset == 0 else current_hop
            remaining = total_tokens - token_offset
            if remaining >= this_hop + pre_lookahead_len:
                end_idx = token_offset + this_hop + pre_lookahead_len
                this_token = torch.tensor([generated_speech_tokens[:end_idx]]).to(self.device)
                finalize = False
            else:
                this_token = torch.tensor([generated_speech_tokens]).to(self.device)
                finalize = True
            with torch.cuda.amp.autocast(self.fp16):
                mel, _ = self.flow.inference(
                    token=this_token,
                    token_len=torch.tensor([this_token.shape[1]]).to(self.device),
                    prompt_token=prompt_token,
                    prompt_token_len=prompt_token_len,
                    prompt_feat=prompt_feat,
                    prompt_feat_len=prompt_feat_len,
                    embedding=embedding,
                    streaming=True,
                    finalize=finalize,
                )
            mel = mel[:, :, token_offset * token_mel_ratio:]
            if hift_cache_mel is not None:
                mel = torch.concat([hift_cache_mel, mel], dim=2)
            hift_cache_mel = mel
            tts_speech, _ = self.hift.inference(speech_feat=mel, finalize=finalize)
            tts_speech = tts_speech[:, speech_offset:]
            speech_offset += tts_speech.shape[1]
            logger.info(f"[stream] token_offset={token_offset}, this_hop={this_hop}, "
                        f"mel_shape={mel.shape}, speech_len={tts_speech.shape[1]}, finalize={finalize}")
            audio_chunks.append(tts_speech)
            token_offset += this_hop
            if not finalize:
                current_hop = min(token_max_hop_len, current_hop * stream_scale_factor)
            else:
                break
        return torch.cat(audio_chunks, dim=1)
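
    # Chunk schedule: with the defaults above, hop sizes grow 25 -> 50 -> 100
    # tokens (doubling each chunk, capped at token_max_hop_len), trading low
    # first-chunk latency for fewer flow calls later. Each non-final chunk feeds
    # pre_lookahead_len extra tokens as right context, and the vocoder re-runs
    # over the cached full mel so chunk boundaries stay consistent; only the
    # not-yet-emitted samples (past speech_offset) are appended.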

    @torch.inference_mode()
    def forward(self, generated_speech_tokens_list, prompt_audios_list,
                prompt_audios_sample_rate, streaming=False):
        assert all(sr == 16000 for sr in prompt_audios_sample_rate)
        prompt_speech_tokens_list = self.prompt_audio_tokenization(prompt_audios_list)
        prompt_mels_for_flow, prompt_mels_lens_for_flow = self.get_prompt_mels(
            prompt_audios_list, prompt_audios_sample_rate)
        spk_emb_for_flow = self.get_spk_emb(prompt_audios_list)
        # Align prompt_speech_feat and prompt_speech_token to exact 2:1 ratio
        # (matches frontend.frontend_zero_shot logic)
        for i in range(len(prompt_speech_tokens_list)):
            token_len = min(int(prompt_mels_lens_for_flow[i].item() / 2),
                            len(prompt_speech_tokens_list[i]))
            prompt_speech_tokens_list[i] = prompt_speech_tokens_list[i][:token_len]
            prompt_mels_lens_for_flow[i] = 2 * token_len
        if streaming:
            generated_wavs = []
            for i in range(len(generated_speech_tokens_list)):
                prompt_feat = prompt_mels_for_flow[i:i + 1, :prompt_mels_lens_for_flow[i]]
                embedding = spk_emb_for_flow[i:i + 1]
                wav = self.forward_stream(
                    generated_speech_tokens_list[i],
                    prompt_speech_tokens_list[i],
                    prompt_feat, embedding,
                )
                generated_wavs.append(wav)
            return generated_wavs
        generated_mels_list = self.forward_flow(
            prompt_speech_tokens_list, generated_speech_tokens_list,
            prompt_mels_for_flow, prompt_mels_lens_for_flow, spk_emb_for_flow)
        generated_wavs = self.forward_hift(generated_mels_list)
        return generated_wavs
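

# Minimal driver sketch. The script's actual CLI entry point is not shown in
# this excerpt; apart from --enable-trt (from the docstring), the argument
# names, the model-dir default, the prompt path and the dummy token list below
# are illustrative assumptions.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-dir', default='pretrained_models/CosyVoice3-0.5B')  # assumed path
    parser.add_argument('--enable-trt', action='store_true')
    parser.add_argument('--streaming', action='store_true')
    args = parser.parse_args()

    model = CosyVoice3_Token2Wav(args.model_dir, enable_trt=args.enable_trt,
                                 streaming=args.streaming)
    # 16 kHz mono prompt; path and token ids are placeholders.
    prompt_wav, sr = torchaudio.load('prompt.wav')
    assert sr == 16000
    fake_tokens = [0] * 150  # ~6 s of speech at 25 tokens/s, ids are dummies
    start = time.time()
    wavs = model([fake_tokens], [prompt_wav.squeeze(0)], [sr], streaming=args.streaming)
    logger.info(f"generated {wavs[0].shape[1] / 24000:.2f}s audio "
                f"in {time.time() - start:.2f}s")
    torchaudio.save('output.wav', wavs[0].cpu().float(), 24000)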