generator.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """HIFI-GAN"""
  15. from typing import Dict, Optional, List
  16. import numpy as np
  17. from scipy.signal import get_window
  18. import torch
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from torch.nn import Conv1d
  22. from torch.nn import ConvTranspose1d
  23. from torch.nn.utils import remove_weight_norm
  24. try:
  25. from torch.nn.utils.parametrizations import weight_norm
  26. except ImportError:
  27. from torch.nn.utils import weight_norm
  28. from torch.distributions.uniform import Uniform
  29. from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
  30. from cosyvoice.transformer.activation import Snake
  31. from cosyvoice.utils.common import get_padding
  32. from cosyvoice.utils.common import init_weights
  33. """hifigan based generator implementation.
  34. This code is modified from https://github.com/jik876/hifi-gan
  35. ,https://github.com/kan-bayashi/ParallelWaveGAN and
  36. https://github.com/NVIDIA/BigVGAN
  37. """
  38. class ResBlock(torch.nn.Module):
  39. """Residual block module in HiFiGAN/BigVGAN."""
  40. def __init__(
  41. self,
  42. channels: int = 512,
  43. kernel_size: int = 3,
  44. dilations: List[int] = [1, 3, 5],
  45. causal: bool = False,
  46. ):
  47. super(ResBlock, self).__init__()
  48. self.causal = causal
  49. self.convs1 = nn.ModuleList()
  50. self.convs2 = nn.ModuleList()
  51. for dilation in dilations:
  52. self.convs1.append(
  53. weight_norm(
  54. Conv1d(
  55. channels,
  56. channels,
  57. kernel_size,
  58. 1,
  59. dilation=dilation,
  60. padding=get_padding(kernel_size, dilation)) if causal is False else
  61. CausalConv1d(
  62. channels,
  63. channels,
  64. kernel_size,
  65. 1,
  66. dilation=dilation,
  67. causal_type='left'
  68. )
  69. )
  70. )
  71. self.convs2.append(
  72. weight_norm(
  73. Conv1d(
  74. channels,
  75. channels,
  76. kernel_size,
  77. 1,
  78. dilation=1,
  79. padding=get_padding(kernel_size, 1)) if causal is False else
  80. CausalConv1d(
  81. channels,
  82. channels,
  83. kernel_size,
  84. 1,
  85. dilation=1,
  86. causal_type='left'
  87. )
  88. )
  89. )
  90. self.convs1.apply(init_weights)
  91. self.convs2.apply(init_weights)
  92. self.activations1 = nn.ModuleList([
  93. Snake(channels, alpha_logscale=False)
  94. for _ in range(len(self.convs1))
  95. ])
  96. self.activations2 = nn.ModuleList([
  97. Snake(channels, alpha_logscale=False)
  98. for _ in range(len(self.convs2))
  99. ])
  100. def forward(self, x: torch.Tensor) -> torch.Tensor:
  101. for idx in range(len(self.convs1)):
  102. xt = self.activations1[idx](x)
  103. xt = self.convs1[idx](xt)
  104. xt = self.activations2[idx](xt)
  105. xt = self.convs2[idx](xt)
  106. x = xt + x
  107. return x
  108. def remove_weight_norm(self):
  109. for idx in range(len(self.convs1)):
  110. remove_weight_norm(self.convs1[idx])
  111. remove_weight_norm(self.convs2[idx])
  112. class SineGen(torch.nn.Module):
  113. """ Definition of sine generator
  114. SineGen(samp_rate, harmonic_num = 0,
  115. sine_amp = 0.1, noise_std = 0.003,
  116. voiced_threshold = 0,
  117. flag_for_pulse=False)
  118. samp_rate: sampling rate in Hz
  119. harmonic_num: number of harmonic overtones (default 0)
  120. sine_amp: amplitude of sine-wavefrom (default 0.1)
  121. noise_std: std of Gaussian noise (default 0.003)
  122. voiced_thoreshold: F0 threshold for U/V classification (default 0)
  123. flag_for_pulse: this SinGen is used inside PulseGen (default False)
  124. Note: when flag_for_pulse is True, the first time step of a voiced
  125. segment is always sin(np.pi) or cos(0)
  126. """
  127. def __init__(self, samp_rate, harmonic_num=0,
  128. sine_amp=0.1, noise_std=0.003,
  129. voiced_threshold=0):
  130. super(SineGen, self).__init__()
  131. self.sine_amp = sine_amp
  132. self.noise_std = noise_std
  133. self.harmonic_num = harmonic_num
  134. self.sampling_rate = samp_rate
  135. self.voiced_threshold = voiced_threshold
  136. def _f02uv(self, f0):
  137. # generate uv signal
  138. uv = (f0 > self.voiced_threshold).type(torch.float32)
  139. return uv
  140. @torch.no_grad()
  141. def forward(self, f0):
  142. """ sine_tensor, uv = forward(f0)
  143. input F0: tensor(batchsize=1, dim=1, length)
  144. f0 for unvoiced steps should be 0
  145. output sine_tensor: tensor(batchsize=1, length, dim)
  146. output uv: tensor(batchsize=1, length, 1)
  147. """
  148. f0 = f0.transpose(1, 2)
  149. F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
  150. for i in range(self.harmonic_num + 1):
  151. F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
  152. theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
  153. u_dist = Uniform(low=-np.pi, high=np.pi)
  154. phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
  155. phase_vec[:, 0, :] = 0
  156. # generate sine waveforms
  157. sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
  158. # generate uv signal
  159. uv = self._f02uv(f0)
  160. # noise: for unvoiced should be similar to sine_amp
  161. # std = self.sine_amp/3 -> max value ~ self.sine_amp
  162. # . for voiced regions is self.noise_std
  163. noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
  164. noise = noise_amp * torch.randn_like(sine_waves)
  165. # first: set the unvoiced part to 0 by uv
  166. # then: additive noise
  167. sine_waves = sine_waves * uv + noise
  168. return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise
  169. class SineGen2(torch.nn.Module):
  170. """ Definition of sine generator
  171. SineGen(samp_rate, harmonic_num = 0,
  172. sine_amp = 0.1, noise_std = 0.003,
  173. voiced_threshold = 0,
  174. flag_for_pulse=False)
  175. samp_rate: sampling rate in Hz
  176. harmonic_num: number of harmonic overtones (default 0)
  177. sine_amp: amplitude of sine-wavefrom (default 0.1)
  178. noise_std: std of Gaussian noise (default 0.003)
  179. voiced_thoreshold: F0 threshold for U/V classification (default 0)
  180. flag_for_pulse: this SinGen is used inside PulseGen (default False)
  181. Note: when flag_for_pulse is True, the first time step of a voiced
  182. segment is always sin(np.pi) or cos(0)
  183. """
  184. def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
  185. sine_amp=0.1, noise_std=0.003,
  186. voiced_threshold=0,
  187. flag_for_pulse=False,
  188. causal=False):
  189. super(SineGen2, self).__init__()
  190. self.sine_amp = sine_amp
  191. self.noise_std = noise_std
  192. self.harmonic_num = harmonic_num
  193. self.dim = self.harmonic_num + 1
  194. self.sampling_rate = samp_rate
  195. self.voiced_threshold = voiced_threshold
  196. self.flag_for_pulse = flag_for_pulse
  197. self.upsample_scale = upsample_scale
  198. self.causal = causal
  199. if causal is True:
  200. self.rand_ini = torch.rand(1, 9)
  201. self.rand_ini[:, 0] = 0
  202. self.sine_waves = torch.rand(1, 300 * 24000, 9)
  203. def _f02uv(self, f0):
  204. # generate uv signal
  205. uv = (f0 > self.voiced_threshold).type(torch.float32)
  206. return uv
  207. def _f02sine(self, f0_values):
  208. """ f0_values: (batchsize, length, dim)
  209. where dim indicates fundamental tone and overtones
  210. """
  211. # convert to F0 in rad. The interger part n can be ignored
  212. # because 2 * np.pi * n doesn't affect phase
  213. rad_values = (f0_values / self.sampling_rate) % 1
  214. # initial phase noise (no noise for fundamental component)
  215. if self.training is False and self.causal is True:
  216. rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
  217. else:
  218. rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
  219. rand_ini[:, 0] = 0
  220. rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
  221. # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
  222. if not self.flag_for_pulse:
  223. rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
  224. scale_factor=1 / self.upsample_scale,
  225. mode="linear").transpose(1, 2)
  226. phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
  227. phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
  228. scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
  229. sines = torch.sin(phase)
  230. else:
  231. # If necessary, make sure that the first time step of every
  232. # voiced segments is sin(pi) or cos(0)
  233. # This is used for pulse-train generation
  234. # identify the last time step in unvoiced segments
  235. uv = self._f02uv(f0_values)
  236. uv_1 = torch.roll(uv, shifts=-1, dims=1)
  237. uv_1[:, -1, :] = 1
  238. u_loc = (uv < 1) * (uv_1 > 0)
  239. # get the instantanouse phase
  240. tmp_cumsum = torch.cumsum(rad_values, dim=1)
  241. # different batch needs to be processed differently
  242. for idx in range(f0_values.shape[0]):
  243. temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
  244. temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
  245. # stores the accumulation of i.phase within
  246. # each voiced segments
  247. tmp_cumsum[idx, :, :] = 0
  248. tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
  249. # rad_values - tmp_cumsum: remove the accumulation of i.phase
  250. # within the previous voiced segment.
  251. i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
  252. # get the sines
  253. sines = torch.cos(i_phase * 2 * np.pi)
  254. return sines
  255. def forward(self, f0):
  256. """ sine_tensor, uv = forward(f0)
  257. input F0: tensor(batchsize=1, length, dim=1)
  258. f0 for unvoiced steps should be 0
  259. output sine_tensor: tensor(batchsize=1, length, dim)
  260. output uv: tensor(batchsize=1, length, 1)
  261. """
  262. # fundamental component
  263. fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
  264. # generate sine waveforms
  265. sine_waves = self._f02sine(fn) * self.sine_amp
  266. # generate uv signal
  267. uv = self._f02uv(f0)
  268. # noise: for unvoiced should be similar to sine_amp
  269. # std = self.sine_amp/3 -> max value ~ self.sine_amp
  270. # . for voiced regions is self.noise_std
  271. noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
  272. if self.training is False and self.causal is True:
  273. noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
  274. else:
  275. noise = noise_amp * torch.randn_like(sine_waves)
  276. # first: set the unvoiced part to 0 by uv
  277. # then: additive noise
  278. sine_waves = sine_waves * uv + noise
  279. return sine_waves, uv, noise
  280. class SourceModuleHnNSF(torch.nn.Module):
  281. """ SourceModule for hn-nsf
  282. SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
  283. add_noise_std=0.003, voiced_threshod=0)
  284. sampling_rate: sampling_rate in Hz
  285. harmonic_num: number of harmonic above F0 (default: 0)
  286. sine_amp: amplitude of sine source signal (default: 0.1)
  287. add_noise_std: std of additive Gaussian noise (default: 0.003)
  288. note that amplitude of noise in unvoiced is decided
  289. by sine_amp
  290. voiced_threshold: threhold to set U/V given F0 (default: 0)
  291. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  292. F0_sampled (batchsize, length, 1)
  293. Sine_source (batchsize, length, 1)
  294. noise_source (batchsize, length 1)
  295. uv (batchsize, length, 1)
  296. """
  297. def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
  298. add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
  299. super(SourceModuleHnNSF, self).__init__()
  300. self.sine_amp = sine_amp
  301. self.noise_std = add_noise_std
  302. # to produce sine waveforms
  303. if sinegen_type == '1':
  304. self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
  305. else:
  306. self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
  307. # to merge source harmonics into a single excitation
  308. self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
  309. self.l_tanh = torch.nn.Tanh()
  310. self.causal = causal
  311. if causal is True:
  312. self.uv = torch.rand(1, 300 * 24000, 1)
  313. def forward(self, x):
  314. """
  315. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  316. F0_sampled (batchsize, length, 1)
  317. Sine_source (batchsize, length, 1)
  318. noise_source (batchsize, length 1)
  319. """
  320. # source for harmonic branch
  321. with torch.no_grad():
  322. sine_wavs, uv, _ = self.l_sin_gen(x)
  323. sine_merge = self.l_tanh(self.l_linear(sine_wavs))
  324. # source for noise branch, in the same shape as uv
  325. if self.training is False and self.causal is True:
  326. noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
  327. else:
  328. noise = torch.randn_like(uv) * self.sine_amp / 3
  329. return sine_merge, noise, uv
  330. class HiFTGenerator(nn.Module):
  331. """
  332. HiFTNet Generator: Neural Source Filter + ISTFTNet
  333. https://arxiv.org/abs/2309.09493
  334. """
  335. def __init__(
  336. self,
  337. in_channels: int = 80,
  338. base_channels: int = 512,
  339. nb_harmonics: int = 8,
  340. sampling_rate: int = 22050,
  341. nsf_alpha: float = 0.1,
  342. nsf_sigma: float = 0.003,
  343. nsf_voiced_threshold: float = 10,
  344. upsample_rates: List[int] = [8, 8],
  345. upsample_kernel_sizes: List[int] = [16, 16],
  346. istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
  347. resblock_kernel_sizes: List[int] = [3, 7, 11],
  348. resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
  349. source_resblock_kernel_sizes: List[int] = [7, 11],
  350. source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
  351. lrelu_slope: float = 0.1,
  352. audio_limit: float = 0.99,
  353. f0_predictor: torch.nn.Module = None,
  354. ):
  355. super(HiFTGenerator, self).__init__()
  356. self.out_channels = 1
  357. self.nb_harmonics = nb_harmonics
  358. self.sampling_rate = sampling_rate
  359. self.istft_params = istft_params
  360. self.lrelu_slope = lrelu_slope
  361. self.audio_limit = audio_limit
  362. self.num_kernels = len(resblock_kernel_sizes)
  363. self.num_upsamples = len(upsample_rates)
  364. # NOTE in CosyVoice2, we use the original SineGen implementation
  365. self.m_source = SourceModuleHnNSF(
  366. sampling_rate=sampling_rate,
  367. upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
  368. harmonic_num=nb_harmonics,
  369. sine_amp=nsf_alpha,
  370. add_noise_std=nsf_sigma,
  371. voiced_threshod=nsf_voiced_threshold,
  372. sinegen_type='1' if self.sampling_rate == 22050 else '2',
  373. causal=False)
  374. self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
  375. self.conv_pre = weight_norm(
  376. Conv1d(in_channels, base_channels, 7, 1, padding=3)
  377. )
  378. # Up
  379. self.ups = nn.ModuleList()
  380. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  381. self.ups.append(
  382. weight_norm(
  383. ConvTranspose1d(
  384. base_channels // (2**i),
  385. base_channels // (2**(i + 1)),
  386. k,
  387. u,
  388. padding=(k - u) // 2,
  389. )
  390. )
  391. )
  392. # Down
  393. self.source_downs = nn.ModuleList()
  394. self.source_resblocks = nn.ModuleList()
  395. downsample_rates = [1] + upsample_rates[::-1][:-1]
  396. downsample_cum_rates = np.cumprod(downsample_rates)
  397. for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
  398. if u == 1:
  399. self.source_downs.append(
  400. Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
  401. )
  402. else:
  403. self.source_downs.append(
  404. Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
  405. )
  406. self.source_resblocks.append(
  407. ResBlock(base_channels // (2 ** (i + 1)), k, d)
  408. )
  409. self.resblocks = nn.ModuleList()
  410. for i in range(len(self.ups)):
  411. ch = base_channels // (2**(i + 1))
  412. for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
  413. self.resblocks.append(ResBlock(ch, k, d))
  414. self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
  415. self.ups.apply(init_weights)
  416. self.conv_post.apply(init_weights)
  417. self.reflection_pad = nn.ReflectionPad1d((1, 0))
  418. self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
  419. self.f0_predictor = f0_predictor
  420. def remove_weight_norm(self):
  421. print('Removing weight norm...')
  422. for l in self.ups:
  423. remove_weight_norm(l)
  424. for l in self.resblocks:
  425. l.remove_weight_norm()
  426. remove_weight_norm(self.conv_pre)
  427. remove_weight_norm(self.conv_post)
  428. self.m_source.remove_weight_norm()
  429. for l in self.source_downs:
  430. remove_weight_norm(l)
  431. for l in self.source_resblocks:
  432. l.remove_weight_norm()
  433. def _stft(self, x):
  434. spec = torch.stft(
  435. x,
  436. self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
  437. return_complex=True)
  438. spec = torch.view_as_real(spec) # [B, F, TT, 2]
  439. return spec[..., 0], spec[..., 1]
  440. def _istft(self, magnitude, phase):
  441. magnitude = torch.clip(magnitude, max=1e2)
  442. real = magnitude * torch.cos(phase)
  443. img = magnitude * torch.sin(phase)
  444. inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
  445. self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
  446. return inverse_transform
  447. def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
  448. s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
  449. s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
  450. x = self.conv_pre(x)
  451. for i in range(self.num_upsamples):
  452. x = F.leaky_relu(x, self.lrelu_slope)
  453. x = self.ups[i](x)
  454. if i == self.num_upsamples - 1:
  455. x = self.reflection_pad(x)
  456. # fusion
  457. si = self.source_downs[i](s_stft)
  458. si = self.source_resblocks[i](si)
  459. x = x + si
  460. xs = None
  461. for j in range(self.num_kernels):
  462. if xs is None:
  463. xs = self.resblocks[i * self.num_kernels + j](x)
  464. else:
  465. xs += self.resblocks[i * self.num_kernels + j](x)
  466. x = xs / self.num_kernels
  467. x = F.leaky_relu(x)
  468. x = self.conv_post(x)
  469. magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
  470. phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
  471. x = self._istft(magnitude, phase)
  472. x = torch.clamp(x, -self.audio_limit, self.audio_limit)
  473. return x
  474. def forward(
  475. self,
  476. batch: dict,
  477. device: torch.device,
  478. ) -> Dict[str, Optional[torch.Tensor]]:
  479. speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
  480. # mel->f0
  481. f0 = self.f0_predictor(speech_feat)
  482. # f0->source
  483. s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
  484. s, _, _ = self.m_source(s)
  485. s = s.transpose(1, 2)
  486. # mel+source->speech
  487. generated_speech = self.decode(x=speech_feat, s=s)
  488. return generated_speech, f0
  489. @torch.inference_mode()
  490. def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
  491. # mel->f0
  492. f0 = self.f0_predictor(speech_feat)
  493. # f0->source
  494. s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
  495. s, _, _ = self.m_source(s)
  496. s = s.transpose(1, 2)
  497. # use cache_source to avoid glitch
  498. if cache_source.shape[2] != 0:
  499. s[:, :, :cache_source.shape[2]] = cache_source
  500. generated_speech = self.decode(x=speech_feat, s=s)
  501. return generated_speech, s
class CausalHiFTGenerator(HiFTGenerator):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    Causal (streaming) variant of HiFTGenerator. It accepts the same arguments
    plus ``conv_pre_look_right`` and builds its own layers instead of running
    ``HiFTGenerator.__init__``: the convolutions are ``CausalConv1d``,
    ``CausalConv1dUpsample`` and ``CausalConv1dDownSample`` modules, every
    ResBlock is built with ``causal=True`` and the source module with
    ``causal=True``. ``_stft``, ``_istft``, ``forward`` and
    ``remove_weight_norm`` are inherited from HiFTGenerator.

    Args:
        conv_pre_look_right: ``conv_pre`` is a ``CausalConv1d`` with kernel size
            ``conv_pre_look_right + 1`` and ``causal_type='right'``; in
            ``decode(finalize=False)`` the last ``conv_pre_look_right`` mel
            frames are passed to it as a separate second argument (presumably
            look-ahead context -- confirm against CausalConv1d).
        Other arguments: see HiFTGenerator.
    """
    def __init__(
            self,
            in_channels: int = 80,
            base_channels: int = 512,
            nb_harmonics: int = 8,
            sampling_rate: int = 22050,
            nsf_alpha: float = 0.1,
            nsf_sigma: float = 0.003,
            nsf_voiced_threshold: float = 10,
            upsample_rates: List[int] = [8, 8],
            upsample_kernel_sizes: List[int] = [16, 16],
            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
            resblock_kernel_sizes: List[int] = [3, 7, 11],
            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            source_resblock_kernel_sizes: List[int] = [7, 11],
            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
            lrelu_slope: float = 0.1,
            audio_limit: float = 0.99,
            conv_pre_look_right: int = 4,
            f0_predictor: torch.nn.Module = None,
    ):
        # Deliberately skips HiFTGenerator.__init__ so the non-causal layers
        # are never built; every attribute the inherited methods use is set below.
        torch.nn.Module.__init__(self)

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold,
            sinegen_type='1' if self.sampling_rate == 22050 else '2',
            causal=True)
        # kept for decode(), which trims look-ahead frames in units of np.prod(upsample_rates)
        self.upsample_rates = upsample_rates
        # frame-rate F0 -> sample-rate F0: total upsampling times the iSTFT hop
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    CausalConv1dUpsample(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                    )
                )
            )

        # Down: bring the source STFT (n_fft + 2 channels: real + imag) to the
        # frame rate and width of every upsampling stage
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
                )
            else:
                self.source_downs.append(
                    CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d, causal=True))

        self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        # periodic Hann window (fftbins=True) used by the inherited _stft/_istft
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.conv_pre_look_right = conv_pre_look_right
        self.f0_predictor = f0_predictor

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
        """Synthesize a waveform chunk from mel features and a source signal.

        Args:
            x: mel features, (batch, in_channels, frames).
            s: source excitation, (batch, 1, samples).
            finalize: True for the last (or only) chunk. When False, the
                trailing ``conv_pre_look_right`` mel frames are only handed to
                ``conv_pre`` as a separate argument, the source STFT frames
                covering them (``np.prod(upsample_rates)`` STFT frames per mel
                frame) are dropped, and the last
                ``np.prod(upsample_rates) * hop_len`` output samples are
                trimmed (presumably because they are not final until more
                input arrives -- confirm).

        Returns:
            (batch, samples) waveform clamped to [-audio_limit, audio_limit].
        """
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        if finalize is True:
            x = self.conv_pre(x)
        else:
            # NOTE(review): requires conv_pre_look_right > 0; with 0 the
            # `:-0` slices below would be empty.
            x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
            s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
            s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion with the source STFT brought to this stage's resolution
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # average the outputs of this stage's num_kernels ResBlocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # NOTE: leaky_relu's default slope (0.01) here, not self.lrelu_slope
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        # first n_fft // 2 + 1 channels: log-magnitude; remaining channels: phase
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        if finalize is False:
            x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> 'tuple[torch.Tensor, torch.Tensor]':
        """Synthesize speech from (a prefix of) mel features, chunk by chunk.

        Args:
            speech_feat: mel features, (batch, in_channels, frames).
            finalize: False for intermediate streaming chunks; forwarded to
                the f0 predictor and to ``decode``.

        Returns:
            (generated_speech, s): the waveform and the (batch, 1, samples)
            source signal that was used.
        """
        # mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
        # (this moves the f0_predictor module to CPU permanently, on every call)
        self.f0_predictor.to('cpu')
        f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        if finalize is True:
            generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
        else:
            # drop the trailing frames that correspond to the f0 predictor's
            # first-layer causal padding (presumably so the mel input lines up
            # with the non-final f0/source output -- TODO confirm)
            generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
        return generated_speech, s
  641. if __name__ == '__main__':
  642. torch.backends.cudnn.deterministic = True
  643. torch.backends.cudnn.benchmark = False
  644. from hyperpyyaml import load_hyperpyyaml
  645. with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
  646. configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
  647. model = configs['hift']
  648. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  649. model.to(device)
  650. model.eval()
  651. max_len, chunk_size, context_size = 300, 30, 8
  652. mel = torch.rand(1, 80, max_len).to(device)
  653. pred_gt, _ = model.inference(mel)
  654. for i in range(0, max_len, chunk_size):
  655. finalize = True if i + chunk_size + context_size >= max_len else False
  656. pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
  657. pred_chunk = pred_chunk[:, i * 480:]
  658. print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())