generator.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """HIFI-GAN"""
  15. from typing import Dict, Optional, List
  16. import numpy as np
  17. from scipy.signal import get_window
  18. import torch
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from torch.nn import Conv1d
  22. from torch.nn import ConvTranspose1d
  23. from torch.nn.utils import remove_weight_norm
  24. try:
  25. from torch.nn.utils.parametrizations import weight_norm
  26. except ImportError:
  27. from torch.nn.utils import weight_norm
  28. from torch.distributions.uniform import Uniform
  29. from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
  30. from cosyvoice.transformer.activation import Snake
  31. from cosyvoice.utils.common import get_padding
  32. from cosyvoice.utils.common import init_weights
"""HiFi-GAN based generator implementation.

This code is modified from https://github.com/jik876/hifi-gan,
https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN.
"""
  38. class ResBlock(torch.nn.Module):
  39. """Residual block module in HiFiGAN/BigVGAN."""
  40. def __init__(
  41. self,
  42. channels: int = 512,
  43. kernel_size: int = 3,
  44. dilations: List[int] = [1, 3, 5],
  45. causal: bool = False,
  46. ):
  47. super(ResBlock, self).__init__()
  48. self.causal = causal
  49. self.convs1 = nn.ModuleList()
  50. self.convs2 = nn.ModuleList()
  51. for dilation in dilations:
  52. self.convs1.append(
  53. weight_norm(
  54. Conv1d(
  55. channels,
  56. channels,
  57. kernel_size,
  58. 1,
  59. dilation=dilation,
  60. padding=get_padding(kernel_size, dilation)) if causal is False else
  61. CausalConv1d(
  62. channels,
  63. channels,
  64. kernel_size,
  65. 1,
  66. dilation=dilation,
  67. causal_type='left'
  68. )
  69. )
  70. )
  71. self.convs2.append(
  72. weight_norm(
  73. Conv1d(
  74. channels,
  75. channels,
  76. kernel_size,
  77. 1,
  78. dilation=1,
  79. padding=get_padding(kernel_size, 1)) if causal is False else
  80. CausalConv1d(
  81. channels,
  82. channels,
  83. kernel_size,
  84. 1,
  85. dilation=1,
  86. causal_type='left'
  87. )
  88. )
  89. )
  90. self.convs1.apply(init_weights)
  91. self.convs2.apply(init_weights)
  92. self.activations1 = nn.ModuleList([
  93. Snake(channels, alpha_logscale=False)
  94. for _ in range(len(self.convs1))
  95. ])
  96. self.activations2 = nn.ModuleList([
  97. Snake(channels, alpha_logscale=False)
  98. for _ in range(len(self.convs2))
  99. ])
  100. def forward(self, x: torch.Tensor) -> torch.Tensor:
  101. for idx in range(len(self.convs1)):
  102. xt = self.activations1[idx](x)
  103. xt = self.convs1[idx](xt)
  104. xt = self.activations2[idx](xt)
  105. xt = self.convs2[idx](xt)
  106. x = xt + x
  107. return x
  108. def remove_weight_norm(self):
  109. for idx in range(len(self.convs1)):
  110. remove_weight_norm(self.convs1[idx])
  111. remove_weight_norm(self.convs2[idx])
  112. class SineGen(torch.nn.Module):
  113. """ Definition of sine generator
  114. SineGen(samp_rate, harmonic_num = 0,
  115. sine_amp = 0.1, noise_std = 0.003,
  116. voiced_threshold = 0,
  117. flag_for_pulse=False)
  118. samp_rate: sampling rate in Hz
  119. harmonic_num: number of harmonic overtones (default 0)
  120. sine_amp: amplitude of sine-wavefrom (default 0.1)
  121. noise_std: std of Gaussian noise (default 0.003)
  122. voiced_thoreshold: F0 threshold for U/V classification (default 0)
  123. flag_for_pulse: this SinGen is used inside PulseGen (default False)
  124. Note: when flag_for_pulse is True, the first time step of a voiced
  125. segment is always sin(np.pi) or cos(0)
  126. """
  127. def __init__(self, samp_rate, harmonic_num=0,
  128. sine_amp=0.1, noise_std=0.003,
  129. voiced_threshold=0):
  130. super(SineGen, self).__init__()
  131. self.sine_amp = sine_amp
  132. self.noise_std = noise_std
  133. self.harmonic_num = harmonic_num
  134. self.sampling_rate = samp_rate
  135. self.voiced_threshold = voiced_threshold
  136. def _f02uv(self, f0):
  137. # generate uv signal
  138. uv = (f0 > self.voiced_threshold).type(torch.float32)
  139. return uv
  140. @torch.no_grad()
  141. def forward(self, f0):
  142. """
  143. :param f0: [B, 1, sample_len], Hz
  144. :return: [B, 1, sample_len]
  145. """
  146. F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
  147. for i in range(self.harmonic_num + 1):
  148. F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
  149. theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
  150. u_dist = Uniform(low=-np.pi, high=np.pi)
  151. phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
  152. phase_vec[:, 0, :] = 0
  153. # generate sine waveforms
  154. sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
  155. # generate uv signal
  156. uv = self._f02uv(f0)
  157. # noise: for unvoiced should be similar to sine_amp
  158. # std = self.sine_amp/3 -> max value ~ self.sine_amp
  159. # . for voiced regions is self.noise_std
  160. noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
  161. noise = noise_amp * torch.randn_like(sine_waves)
  162. # first: set the unvoiced part to 0 by uv
  163. # then: additive noise
  164. sine_waves = sine_waves * uv + noise
  165. return sine_waves, uv, noise
  166. class SineGen2(torch.nn.Module):
  167. """ Definition of sine generator
  168. SineGen(samp_rate, harmonic_num = 0,
  169. sine_amp = 0.1, noise_std = 0.003,
  170. voiced_threshold = 0,
  171. flag_for_pulse=False)
  172. samp_rate: sampling rate in Hz
  173. harmonic_num: number of harmonic overtones (default 0)
  174. sine_amp: amplitude of sine-wavefrom (default 0.1)
  175. noise_std: std of Gaussian noise (default 0.003)
  176. voiced_thoreshold: F0 threshold for U/V classification (default 0)
  177. flag_for_pulse: this SinGen is used inside PulseGen (default False)
  178. Note: when flag_for_pulse is True, the first time step of a voiced
  179. segment is always sin(np.pi) or cos(0)
  180. """
  181. def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
  182. sine_amp=0.1, noise_std=0.003,
  183. voiced_threshold=0,
  184. flag_for_pulse=False,
  185. causal=False):
  186. super(SineGen2, self).__init__()
  187. self.sine_amp = sine_amp
  188. self.noise_std = noise_std
  189. self.harmonic_num = harmonic_num
  190. self.dim = self.harmonic_num + 1
  191. self.sampling_rate = samp_rate
  192. self.voiced_threshold = voiced_threshold
  193. self.flag_for_pulse = flag_for_pulse
  194. self.upsample_scale = upsample_scale
  195. self.causal = causal
  196. if causal is True:
  197. self.rand_ini = torch.rand(1, 9)
  198. self.rand_ini[:, 0] = 0
  199. self.sine_waves = torch.rand(1, 60 * 16000, 9)
  200. def _f02uv(self, f0):
  201. # generate uv signal
  202. uv = (f0 > self.voiced_threshold).type(torch.float32)
  203. return uv
  204. def _f02sine(self, f0_values):
  205. """ f0_values: (batchsize, length, dim)
  206. where dim indicates fundamental tone and overtones
  207. """
  208. # convert to F0 in rad. The interger part n can be ignored
  209. # because 2 * np.pi * n doesn't affect phase
  210. rad_values = (f0_values / self.sampling_rate) % 1
  211. # initial phase noise (no noise for fundamental component)
  212. if self.training is False and self.causal is True:
  213. rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
  214. else:
  215. rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
  216. rand_ini[:, 0] = 0
  217. rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
  218. # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
  219. if not self.flag_for_pulse:
  220. rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
  221. scale_factor=1 / self.upsample_scale,
  222. mode="linear").transpose(1, 2)
  223. phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
  224. phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
  225. scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
  226. sines = torch.sin(phase)
  227. else:
  228. # If necessary, make sure that the first time step of every
  229. # voiced segments is sin(pi) or cos(0)
  230. # This is used for pulse-train generation
  231. # identify the last time step in unvoiced segments
  232. uv = self._f02uv(f0_values)
  233. uv_1 = torch.roll(uv, shifts=-1, dims=1)
  234. uv_1[:, -1, :] = 1
  235. u_loc = (uv < 1) * (uv_1 > 0)
  236. # get the instantanouse phase
  237. tmp_cumsum = torch.cumsum(rad_values, dim=1)
  238. # different batch needs to be processed differently
  239. for idx in range(f0_values.shape[0]):
  240. temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
  241. temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
  242. # stores the accumulation of i.phase within
  243. # each voiced segments
  244. tmp_cumsum[idx, :, :] = 0
  245. tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
  246. # rad_values - tmp_cumsum: remove the accumulation of i.phase
  247. # within the previous voiced segment.
  248. i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
  249. # get the sines
  250. sines = torch.cos(i_phase * 2 * np.pi)
  251. return sines
  252. def forward(self, f0):
  253. """ sine_tensor, uv = forward(f0)
  254. input F0: tensor(batchsize=1, length, dim=1)
  255. f0 for unvoiced steps should be 0
  256. output sine_tensor: tensor(batchsize=1, length, dim)
  257. output uv: tensor(batchsize=1, length, 1)
  258. """
  259. # fundamental component
  260. fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
  261. # generate sine waveforms
  262. sine_waves = self._f02sine(fn) * self.sine_amp
  263. # generate uv signal
  264. uv = self._f02uv(f0)
  265. # noise: for unvoiced should be similar to sine_amp
  266. # std = self.sine_amp/3 -> max value ~ self.sine_amp
  267. # . for voiced regions is self.noise_std
  268. noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
  269. if self.training is False and self.causal is True:
  270. noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
  271. else:
  272. noise = noise_amp * torch.randn_like(sine_waves)
  273. # first: set the unvoiced part to 0 by uv
  274. # then: additive noise
  275. sine_waves = sine_waves * uv + noise
  276. return sine_waves, uv, noise
  277. class SourceModuleHnNSF(torch.nn.Module):
  278. """ SourceModule for hn-nsf
  279. SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
  280. add_noise_std=0.003, voiced_threshod=0)
  281. sampling_rate: sampling_rate in Hz
  282. harmonic_num: number of harmonic above F0 (default: 0)
  283. sine_amp: amplitude of sine source signal (default: 0.1)
  284. add_noise_std: std of additive Gaussian noise (default: 0.003)
  285. note that amplitude of noise in unvoiced is decided
  286. by sine_amp
  287. voiced_threshold: threhold to set U/V given F0 (default: 0)
  288. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  289. F0_sampled (batchsize, length, 1)
  290. Sine_source (batchsize, length, 1)
  291. noise_source (batchsize, length 1)
  292. uv (batchsize, length, 1)
  293. """
  294. def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
  295. add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
  296. super(SourceModuleHnNSF, self).__init__()
  297. self.sine_amp = sine_amp
  298. self.noise_std = add_noise_std
  299. # to produce sine waveforms
  300. if sinegen_type == '1':
  301. self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
  302. else:
  303. self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
  304. # to merge source harmonics into a single excitation
  305. self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
  306. self.l_tanh = torch.nn.Tanh()
  307. self.causal = causal
  308. if causal is True:
  309. self.uv = torch.rand(1, 60 * 24000, 1)
  310. def forward(self, x):
  311. """
  312. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  313. F0_sampled (batchsize, length, 1)
  314. Sine_source (batchsize, length, 1)
  315. noise_source (batchsize, length 1)
  316. """
  317. # source for harmonic branch
  318. with torch.no_grad():
  319. sine_wavs, uv, _ = self.l_sin_gen(x)
  320. sine_merge = self.l_tanh(self.l_linear(sine_wavs))
  321. # source for noise branch, in the same shape as uv
  322. if self.training is False and self.causal is True:
  323. noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
  324. else:
  325. noise = torch.randn_like(uv) * self.sine_amp / 3
  326. return sine_merge, noise, uv
  327. class HiFTGenerator(nn.Module):
  328. """
  329. HiFTNet Generator: Neural Source Filter + ISTFTNet
  330. https://arxiv.org/abs/2309.09493
  331. """
  332. def __init__(
  333. self,
  334. in_channels: int = 80,
  335. base_channels: int = 512,
  336. nb_harmonics: int = 8,
  337. sampling_rate: int = 22050,
  338. nsf_alpha: float = 0.1,
  339. nsf_sigma: float = 0.003,
  340. nsf_voiced_threshold: float = 10,
  341. upsample_rates: List[int] = [8, 8],
  342. upsample_kernel_sizes: List[int] = [16, 16],
  343. istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
  344. resblock_kernel_sizes: List[int] = [3, 7, 11],
  345. resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
  346. source_resblock_kernel_sizes: List[int] = [7, 11],
  347. source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
  348. lrelu_slope: float = 0.1,
  349. audio_limit: float = 0.99,
  350. f0_predictor: torch.nn.Module = None,
  351. ):
  352. super(HiFTGenerator, self).__init__()
  353. self.out_channels = 1
  354. self.nb_harmonics = nb_harmonics
  355. self.sampling_rate = sampling_rate
  356. self.istft_params = istft_params
  357. self.lrelu_slope = lrelu_slope
  358. self.audio_limit = audio_limit
  359. self.num_kernels = len(resblock_kernel_sizes)
  360. self.num_upsamples = len(upsample_rates)
  361. # NOTE in CosyVoice2, we use the original SineGen implementation
  362. self.m_source = SourceModuleHnNSF(
  363. sampling_rate=sampling_rate,
  364. upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
  365. harmonic_num=nb_harmonics,
  366. sine_amp=nsf_alpha,
  367. add_noise_std=nsf_sigma,
  368. voiced_threshod=nsf_voiced_threshold,
  369. sinegen_type='1' if self.sampling_rate == 22050 else '2',
  370. causal=False)
  371. self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
  372. self.conv_pre = weight_norm(
  373. Conv1d(in_channels, base_channels, 7, 1, padding=3)
  374. )
  375. # Up
  376. self.ups = nn.ModuleList()
  377. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  378. self.ups.append(
  379. weight_norm(
  380. ConvTranspose1d(
  381. base_channels // (2**i),
  382. base_channels // (2**(i + 1)),
  383. k,
  384. u,
  385. padding=(k - u) // 2,
  386. )
  387. )
  388. )
  389. # Down
  390. self.source_downs = nn.ModuleList()
  391. self.source_resblocks = nn.ModuleList()
  392. downsample_rates = [1] + upsample_rates[::-1][:-1]
  393. downsample_cum_rates = np.cumprod(downsample_rates)
  394. for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
  395. if u == 1:
  396. self.source_downs.append(
  397. Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
  398. )
  399. else:
  400. self.source_downs.append(
  401. Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
  402. )
  403. self.source_resblocks.append(
  404. ResBlock(base_channels // (2 ** (i + 1)), k, d)
  405. )
  406. self.resblocks = nn.ModuleList()
  407. for i in range(len(self.ups)):
  408. ch = base_channels // (2**(i + 1))
  409. for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
  410. self.resblocks.append(ResBlock(ch, k, d))
  411. self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
  412. self.ups.apply(init_weights)
  413. self.conv_post.apply(init_weights)
  414. self.reflection_pad = nn.ReflectionPad1d((1, 0))
  415. self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
  416. self.f0_predictor = f0_predictor
  417. def remove_weight_norm(self):
  418. print('Removing weight norm...')
  419. for l in self.ups:
  420. remove_weight_norm(l)
  421. for l in self.resblocks:
  422. l.remove_weight_norm()
  423. remove_weight_norm(self.conv_pre)
  424. remove_weight_norm(self.conv_post)
  425. self.m_source.remove_weight_norm()
  426. for l in self.source_downs:
  427. remove_weight_norm(l)
  428. for l in self.source_resblocks:
  429. l.remove_weight_norm()
  430. def _stft(self, x):
  431. spec = torch.stft(
  432. x,
  433. self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
  434. return_complex=True)
  435. spec = torch.view_as_real(spec) # [B, F, TT, 2]
  436. return spec[..., 0], spec[..., 1]
  437. def _istft(self, magnitude, phase):
  438. magnitude = torch.clip(magnitude, max=1e2)
  439. real = magnitude * torch.cos(phase)
  440. img = magnitude * torch.sin(phase)
  441. inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
  442. self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
  443. return inverse_transform
  444. def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
  445. s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
  446. s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
  447. x = self.conv_pre(x)
  448. for i in range(self.num_upsamples):
  449. x = F.leaky_relu(x, self.lrelu_slope)
  450. x = self.ups[i](x)
  451. if i == self.num_upsamples - 1:
  452. x = self.reflection_pad(x)
  453. # fusion
  454. si = self.source_downs[i](s_stft)
  455. si = self.source_resblocks[i](si)
  456. x = x + si
  457. xs = None
  458. for j in range(self.num_kernels):
  459. if xs is None:
  460. xs = self.resblocks[i * self.num_kernels + j](x)
  461. else:
  462. xs += self.resblocks[i * self.num_kernels + j](x)
  463. x = xs / self.num_kernels
  464. x = F.leaky_relu(x)
  465. x = self.conv_post(x)
  466. magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
  467. phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
  468. x = self._istft(magnitude, phase)
  469. x = torch.clamp(x, -self.audio_limit, self.audio_limit)
  470. return x
  471. def forward(
  472. self,
  473. batch: dict,
  474. device: torch.device,
  475. ) -> Dict[str, Optional[torch.Tensor]]:
  476. speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
  477. # mel->f0
  478. f0 = self.f0_predictor(speech_feat)
  479. # f0->source
  480. s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
  481. s, _, _ = self.m_source(s)
  482. s = s.transpose(1, 2)
  483. # mel+source->speech
  484. generated_speech = self.decode(x=speech_feat, s=s)
  485. return generated_speech, f0
  486. @torch.inference_mode()
  487. def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
  488. # mel->f0
  489. f0 = self.f0_predictor(speech_feat)
  490. # f0->source
  491. s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
  492. s, _, _ = self.m_source(s)
  493. s = s.transpose(1, 2)
  494. # use cache_source to avoid glitch
  495. if cache_source.shape[2] != 0:
  496. s[:, :, :cache_source.shape[2]] = cache_source
  497. generated_speech = self.decode(x=speech_feat, s=s)
  498. return generated_speech, s
  499. class CausalHiFTGenerator(HiFTGenerator):
  500. """
  501. HiFTNet Generator: Neural Source Filter + ISTFTNet
  502. https://arxiv.org/abs/2309.09493
  503. """
  504. def __init__(
  505. self,
  506. in_channels: int = 80,
  507. base_channels: int = 512,
  508. nb_harmonics: int = 8,
  509. sampling_rate: int = 22050,
  510. nsf_alpha: float = 0.1,
  511. nsf_sigma: float = 0.003,
  512. nsf_voiced_threshold: float = 10,
  513. upsample_rates: List[int] = [8, 8],
  514. upsample_kernel_sizes: List[int] = [16, 16],
  515. istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
  516. resblock_kernel_sizes: List[int] = [3, 7, 11],
  517. resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
  518. source_resblock_kernel_sizes: List[int] = [7, 11],
  519. source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
  520. lrelu_slope: float = 0.1,
  521. audio_limit: float = 0.99,
  522. conv_pre_look_right: int = 4,
  523. f0_predictor: torch.nn.Module = None,
  524. ):
  525. torch.nn.Module.__init__(self)
  526. self.out_channels = 1
  527. self.nb_harmonics = nb_harmonics
  528. self.sampling_rate = sampling_rate
  529. self.istft_params = istft_params
  530. self.lrelu_slope = lrelu_slope
  531. self.audio_limit = audio_limit
  532. self.num_kernels = len(resblock_kernel_sizes)
  533. self.num_upsamples = len(upsample_rates)
  534. self.m_source = SourceModuleHnNSF(
  535. sampling_rate=sampling_rate,
  536. upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
  537. harmonic_num=nb_harmonics,
  538. sine_amp=nsf_alpha,
  539. add_noise_std=nsf_sigma,
  540. voiced_threshod=nsf_voiced_threshold,
  541. sinegen_type='1' if self.sampling_rate == 22050 else '2',
  542. causal=True)
  543. self.upsample_rates = upsample_rates
  544. self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
  545. self.conv_pre = weight_norm(
  546. CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
  547. )
  548. # Up
  549. self.ups = nn.ModuleList()
  550. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  551. self.ups.append(
  552. weight_norm(
  553. CausalConv1dUpsample(
  554. base_channels // (2**i),
  555. base_channels // (2**(i + 1)),
  556. k,
  557. u,
  558. )
  559. )
  560. )
  561. # Down
  562. self.source_downs = nn.ModuleList()
  563. self.source_resblocks = nn.ModuleList()
  564. downsample_rates = [1] + upsample_rates[::-1][:-1]
  565. downsample_cum_rates = np.cumprod(downsample_rates)
  566. for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
  567. if u == 1:
  568. self.source_downs.append(
  569. CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
  570. )
  571. else:
  572. self.source_downs.append(
  573. CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
  574. )
  575. self.source_resblocks.append(
  576. ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
  577. )
  578. self.resblocks = nn.ModuleList()
  579. for i in range(len(self.ups)):
  580. ch = base_channels // (2**(i + 1))
  581. for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
  582. self.resblocks.append(ResBlock(ch, k, d, causal=True))
  583. self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
  584. self.ups.apply(init_weights)
  585. self.conv_post.apply(init_weights)
  586. self.reflection_pad = nn.ReflectionPad1d((1, 0))
  587. self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
  588. self.conv_pre_look_right = conv_pre_look_right
  589. self.f0_predictor = f0_predictor
  590. def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
  591. s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
  592. if finalize is True:
  593. x = self.conv_pre(x)
  594. else:
  595. x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
  596. s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
  597. s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
  598. s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
  599. for i in range(self.num_upsamples):
  600. x = F.leaky_relu(x, self.lrelu_slope)
  601. x = self.ups[i](x)
  602. if i == self.num_upsamples - 1:
  603. x = self.reflection_pad(x)
  604. # fusion
  605. si = self.source_downs[i](s_stft)
  606. si = self.source_resblocks[i](si)
  607. x = x + si
  608. xs = None
  609. for j in range(self.num_kernels):
  610. if xs is None:
  611. xs = self.resblocks[i * self.num_kernels + j](x)
  612. else:
  613. xs += self.resblocks[i * self.num_kernels + j](x)
  614. x = xs / self.num_kernels
  615. x = F.leaky_relu(x)
  616. x = self.conv_post(x)
  617. magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
  618. phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
  619. x = self._istft(magnitude, phase)
  620. if finalize is False:
  621. x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
  622. x = torch.clamp(x, -self.audio_limit, self.audio_limit)
  623. return x
  624. @torch.inference_mode()
  625. def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
  626. # mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
  627. self.f0_predictor.to('cpu')
  628. f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
  629. # f0->source
  630. s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
  631. s, _, _ = self.m_source(s)
  632. s = s.transpose(1, 2)
  633. if finalize is True:
  634. generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
  635. else:
  636. generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
  637. return generated_speech, s
  638. if __name__ == '__main__':
  639. torch.backends.cudnn.deterministic = True
  640. torch.backends.cudnn.benchmark = False
  641. from hyperpyyaml import load_hyperpyyaml
  642. with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
  643. configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
  644. model = configs['hift']
  645. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  646. model.to(device)
  647. model.eval()
  648. max_len, chunk_size, context_size = 300, 30, 8
  649. mel = torch.rand(1, 80, max_len).to(device)
  650. pred_gt, _ = model.inference(mel)
  651. for i in range(0, max_len, chunk_size):
  652. finalize = True if i + chunk_size + context_size >= max_len else False
  653. pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
  654. pred_chunk = pred_chunk[:, i * 480:]
  655. print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())