# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HiFi-GAN"""
import typing as tp

import numpy as np
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform

from cosyvoice.transformer.activation import Snake
from academicodec.utils import get_padding
from academicodec.utils import init_weights
  29. """hifigan based generator implementation.
  30. This code is modified from https://github.com/jik876/hifi-gan
  31. ,https://github.com/kan-bayashi/ParallelWaveGAN and
  32. https://github.com/NVIDIA/BigVGAN
  33. """
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN."""

    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            xt = self.activations2[idx](xt)
            xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            remove_weight_norm(self.convs2[idx])
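
# A minimal usage sketch for ResBlock (the shapes below are illustrative,
# not part of the API):
#
#   block = ResBlock(channels=512, kernel_size=3, dilations=[1, 3, 5])
#   feats = torch.randn(2, 512, 100)   # [batch, channels, frames]
#   out = block(feats)                 # residual path preserves [2, 512, 100]
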
class SineGen(torch.nn.Module):
    """Definition of the sine generator

    SineGen(samp_rate, harmonic_num=0,
            sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate the voiced/unvoiced (U/V) mask
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: sine_waves [B, harmonic_num + 1, sample_len],
                 uv [B, 1, sample_len], noise [B, harmonic_num + 1, sample_len]
        """
        # per-harmonic instantaneous frequency as a fraction of the sampling rate
        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate

        # integrate frequency into phase; random initial phase per harmonic,
        # with the fundamental fixed at phase 0
        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate the U/V mask
        uv = self._f02uv(f0)

        # noise: in unvoiced regions the amplitude should be similar to sine_amp
        # (std = sine_amp / 3 -> max value ~ sine_amp); in voiced regions it is noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first zero out the unvoiced part via uv, then add the noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
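
# A minimal sketch of SineGen on a constant 100 Hz F0 track (values are
# illustrative only):
#
#   gen = SineGen(samp_rate=22050, harmonic_num=2)
#   f0 = torch.full((1, 1, 22050), 100.0)   # one second of voiced 100 Hz
#   sine_waves, uv, noise = gen(f0)
#   # sine_waves: [1, 3, 22050] -- the fundamental plus two harmonics
#   # uv: [1, 1, 22050], all ones here since 100 Hz > voiced_threshold
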
class SourceModuleHnNSF(torch.nn.Module):
    """SourceModule for hn-nsf

    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshold=0)

    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003);
        note that the amplitude of noise in unvoiced regions is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)

    Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshold=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        # upsample_scale is unused here; it is kept for config compatibility

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshold)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        # source for the harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for the noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
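
# A minimal sketch of SourceModuleHnNSF; it expects F0 already upsampled to
# the waveform rate, shaped [B, length, 1] (upsample_scale=256 below is just
# an illustrative value, since the module does not use it internally):
#
#   src = SourceModuleHnNSF(sampling_rate=22050, upsample_scale=256, harmonic_num=8)
#   f0 = torch.full((1, 22050, 1), 100.0)
#   sine_merge, noise, uv = src(f0)   # each [1, 22050, 1]
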
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493
    """

    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8],
        upsample_kernel_sizes: tp.List[int] = [16, 16],
        istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: tp.List[int] = [7, 11],
        source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: torch.nn.Module = None,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshold=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up: mel-branch upsampling
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down: source-branch downsampling to match each upsampled resolution
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
                                          source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )
            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # [B, T, 1]
        har_source, _, _ = self.m_source(f0)
        return har_source.transpose(1, 2)

    def _stft(self, x):
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"],
            window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        imag = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(
            torch.complex(real, imag),
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"],
            window=self.stft_window.to(magnitude.device))
        return inverse_transform
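
    # How the ISTFTNet head fits together, given the default
    # istft_params = {"n_fft": 16, "hop_len": 4}: conv_post emits
    # n_fft + 2 = 18 channels per frame, which forward() splits into
    # n_fft // 2 + 1 = 9 log-magnitude bins and 9 phase bins; _istft()
    # then inverts the (magnitude, phase) pair to a waveform, advancing
    # hop_len = 4 samples per frame.
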
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f0 = self.f0_predictor(x)
        s = self._f02source(f0)

        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fuse the downsampled source spectrum into the mel branch
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # the sin() is, strictly speaking, redundant
        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x
    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        # Neither self.m_source (a plain Linear) nor self.source_downs (plain
        # Conv1d layers never wrapped in weight_norm) carry weight norm, so
        # only the source ResBlocks are left to clean up.
        for l in self.source_resblocks:
            l.remove_weight_norm()
    @torch.inference_mode()
    def inference(self, mel: torch.Tensor) -> torch.Tensor:
        return self.forward(x=mel)
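
# An end-to-end usage sketch. DummyF0Predictor is a hypothetical stand-in for
# a real F0 predictor; any module mapping [B, 80, T_mel] mel features to a
# [B, T_mel] F0 track in Hz would do:
#
#   class DummyF0Predictor(torch.nn.Module):
#       def forward(self, mel):
#           # constant 220 Hz, i.e. fully voiced under the default threshold
#           return torch.full((mel.size(0), mel.size(2)), 220.0)
#
#   hift = HiFTGenerator(f0_predictor=DummyF0Predictor())
#   mel = torch.randn(1, 80, 50)   # [B, in_channels, T_mel]
#   wav = hift.inference(mel)      # [1, 12800]: 50 frames x (8 * 8 * 4) samples each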