generator.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """HIFI-GAN"""
  15. from typing import Dict, Optional, List
  16. import numpy as np
  17. from scipy.signal import get_window
  18. import torch
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from torch.nn import Conv1d
  22. from torch.nn import ConvTranspose1d
  23. from torch.nn.utils import remove_weight_norm
  24. try:
  25. from torch.nn.utils.parametrizations import weight_norm
  26. except ImportError:
  27. from torch.nn.utils import weight_norm
  28. from torch.distributions.uniform import Uniform
  29. from cosyvoice.transformer.activation import Snake
  30. from cosyvoice.utils.common import get_padding
  31. from cosyvoice.utils.common import init_weights
  32. """hifigan based generator implementation.
  33. This code is modified from https://github.com/jik876/hifi-gan
  34. ,https://github.com/kan-bayashi/ParallelWaveGAN and
  35. https://github.com/NVIDIA/BigVGAN
  36. """
  37. class ResBlock(torch.nn.Module):
  38. """Residual block module in HiFiGAN/BigVGAN."""
  39. def __init__(
  40. self,
  41. channels: int = 512,
  42. kernel_size: int = 3,
  43. dilations: List[int] = [1, 3, 5],
  44. ):
  45. super(ResBlock, self).__init__()
  46. self.convs1 = nn.ModuleList()
  47. self.convs2 = nn.ModuleList()
  48. for dilation in dilations:
  49. self.convs1.append(
  50. weight_norm(
  51. Conv1d(
  52. channels,
  53. channels,
  54. kernel_size,
  55. 1,
  56. dilation=dilation,
  57. padding=get_padding(kernel_size, dilation)
  58. )
  59. )
  60. )
  61. self.convs2.append(
  62. weight_norm(
  63. Conv1d(
  64. channels,
  65. channels,
  66. kernel_size,
  67. 1,
  68. dilation=1,
  69. padding=get_padding(kernel_size, 1)
  70. )
  71. )
  72. )
  73. self.convs1.apply(init_weights)
  74. self.convs2.apply(init_weights)
  75. self.activations1 = nn.ModuleList([
  76. Snake(channels, alpha_logscale=False)
  77. for _ in range(len(self.convs1))
  78. ])
  79. self.activations2 = nn.ModuleList([
  80. Snake(channels, alpha_logscale=False)
  81. for _ in range(len(self.convs2))
  82. ])
  83. def forward(self, x: torch.Tensor) -> torch.Tensor:
  84. for idx in range(len(self.convs1)):
  85. xt = self.activations1[idx](x)
  86. xt = self.convs1[idx](xt)
  87. xt = self.activations2[idx](xt)
  88. xt = self.convs2[idx](xt)
  89. x = xt + x
  90. return x
  91. def remove_weight_norm(self):
  92. for idx in range(len(self.convs1)):
  93. remove_weight_norm(self.convs1[idx])
  94. remove_weight_norm(self.convs2[idx])
  95. class SineGen(torch.nn.Module):
  96. """ Definition of sine generator
  97. SineGen(samp_rate, harmonic_num = 0,
  98. sine_amp = 0.1, noise_std = 0.003,
  99. voiced_threshold = 0,
  100. flag_for_pulse=False)
  101. samp_rate: sampling rate in Hz
  102. harmonic_num: number of harmonic overtones (default 0)
  103. sine_amp: amplitude of sine-wavefrom (default 0.1)
  104. noise_std: std of Gaussian noise (default 0.003)
  105. voiced_thoreshold: F0 threshold for U/V classification (default 0)
  106. flag_for_pulse: this SinGen is used inside PulseGen (default False)
  107. Note: when flag_for_pulse is True, the first time step of a voiced
  108. segment is always sin(np.pi) or cos(0)
  109. """
  110. def __init__(self, samp_rate, harmonic_num=0,
  111. sine_amp=0.1, noise_std=0.003,
  112. voiced_threshold=0):
  113. super(SineGen, self).__init__()
  114. self.sine_amp = sine_amp
  115. self.noise_std = noise_std
  116. self.harmonic_num = harmonic_num
  117. self.sampling_rate = samp_rate
  118. self.voiced_threshold = voiced_threshold
  119. def _f02uv(self, f0):
  120. # generate uv signal
  121. uv = (f0 > self.voiced_threshold).type(torch.float32)
  122. return uv
  123. @torch.no_grad()
  124. def forward(self, f0):
  125. """
  126. :param f0: [B, 1, sample_len], Hz
  127. :return: [B, 1, sample_len]
  128. """
  129. F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
  130. for i in range(self.harmonic_num + 1):
  131. F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
  132. theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
  133. u_dist = Uniform(low=-np.pi, high=np.pi)
  134. phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
  135. phase_vec[:, 0, :] = 0
  136. # generate sine waveforms
  137. sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
  138. # generate uv signal
  139. uv = self._f02uv(f0)
  140. # noise: for unvoiced should be similar to sine_amp
  141. # std = self.sine_amp/3 -> max value ~ self.sine_amp
  142. # . for voiced regions is self.noise_std
  143. noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
  144. noise = noise_amp * torch.randn_like(sine_waves)
  145. # first: set the unvoiced part to 0 by uv
  146. # then: additive noise
  147. sine_waves = sine_waves * uv + noise
  148. return sine_waves, uv, noise
  149. class SourceModuleHnNSF(torch.nn.Module):
  150. """ SourceModule for hn-nsf
  151. SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
  152. add_noise_std=0.003, voiced_threshod=0)
  153. sampling_rate: sampling_rate in Hz
  154. harmonic_num: number of harmonic above F0 (default: 0)
  155. sine_amp: amplitude of sine source signal (default: 0.1)
  156. add_noise_std: std of additive Gaussian noise (default: 0.003)
  157. note that amplitude of noise in unvoiced is decided
  158. by sine_amp
  159. voiced_threshold: threhold to set U/V given F0 (default: 0)
  160. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  161. F0_sampled (batchsize, length, 1)
  162. Sine_source (batchsize, length, 1)
  163. noise_source (batchsize, length 1)
  164. uv (batchsize, length, 1)
  165. """
  166. def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
  167. add_noise_std=0.003, voiced_threshod=0):
  168. super(SourceModuleHnNSF, self).__init__()
  169. self.sine_amp = sine_amp
  170. self.noise_std = add_noise_std
  171. # to produce sine waveforms
  172. self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
  173. sine_amp, add_noise_std, voiced_threshod)
  174. # to merge source harmonics into a single excitation
  175. self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
  176. self.l_tanh = torch.nn.Tanh()
  177. def forward(self, x):
  178. """
  179. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  180. F0_sampled (batchsize, length, 1)
  181. Sine_source (batchsize, length, 1)
  182. noise_source (batchsize, length 1)
  183. """
  184. # source for harmonic branch
  185. with torch.no_grad():
  186. sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
  187. sine_wavs = sine_wavs.transpose(1, 2)
  188. uv = uv.transpose(1, 2)
  189. sine_merge = self.l_tanh(self.l_linear(sine_wavs))
  190. # source for noise branch, in the same shape as uv
  191. noise = torch.randn_like(uv) * self.sine_amp / 3
  192. return sine_merge, noise, uv
  193. class HiFTGenerator(nn.Module):
  194. """
  195. HiFTNet Generator: Neural Source Filter + ISTFTNet
  196. https://arxiv.org/abs/2309.09493
  197. """
  198. def __init__(
  199. self,
  200. in_channels: int = 80,
  201. base_channels: int = 512,
  202. nb_harmonics: int = 8,
  203. sampling_rate: int = 22050,
  204. nsf_alpha: float = 0.1,
  205. nsf_sigma: float = 0.003,
  206. nsf_voiced_threshold: float = 10,
  207. upsample_rates: List[int] = [8, 8],
  208. upsample_kernel_sizes: List[int] = [16, 16],
  209. istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
  210. resblock_kernel_sizes: List[int] = [3, 7, 11],
  211. resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
  212. source_resblock_kernel_sizes: List[int] = [7, 11],
  213. source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
  214. lrelu_slope: float = 0.1,
  215. audio_limit: float = 0.99,
  216. f0_predictor: torch.nn.Module = None,
  217. ):
  218. super(HiFTGenerator, self).__init__()
  219. self.out_channels = 1
  220. self.nb_harmonics = nb_harmonics
  221. self.sampling_rate = sampling_rate
  222. self.istft_params = istft_params
  223. self.lrelu_slope = lrelu_slope
  224. self.audio_limit = audio_limit
  225. self.num_kernels = len(resblock_kernel_sizes)
  226. self.num_upsamples = len(upsample_rates)
  227. self.m_source = SourceModuleHnNSF(
  228. sampling_rate=sampling_rate,
  229. upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
  230. harmonic_num=nb_harmonics,
  231. sine_amp=nsf_alpha,
  232. add_noise_std=nsf_sigma,
  233. voiced_threshod=nsf_voiced_threshold)
  234. self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
  235. self.conv_pre = weight_norm(
  236. Conv1d(in_channels, base_channels, 7, 1, padding=3)
  237. )
  238. # Up
  239. self.ups = nn.ModuleList()
  240. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  241. self.ups.append(
  242. weight_norm(
  243. ConvTranspose1d(
  244. base_channels // (2**i),
  245. base_channels // (2**(i + 1)),
  246. k,
  247. u,
  248. padding=(k - u) // 2,
  249. )
  250. )
  251. )
  252. # Down
  253. self.source_downs = nn.ModuleList()
  254. self.source_resblocks = nn.ModuleList()
  255. downsample_rates = [1] + upsample_rates[::-1][:-1]
  256. downsample_cum_rates = np.cumprod(downsample_rates)
  257. for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
  258. if u == 1:
  259. self.source_downs.append(
  260. Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
  261. )
  262. else:
  263. self.source_downs.append(
  264. Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
  265. )
  266. self.source_resblocks.append(
  267. ResBlock(base_channels // (2 ** (i + 1)), k, d)
  268. )
  269. self.resblocks = nn.ModuleList()
  270. for i in range(len(self.ups)):
  271. ch = base_channels // (2**(i + 1))
  272. for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
  273. self.resblocks.append(ResBlock(ch, k, d))
  274. self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
  275. self.ups.apply(init_weights)
  276. self.conv_post.apply(init_weights)
  277. self.reflection_pad = nn.ReflectionPad1d((1, 0))
  278. self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
  279. self.f0_predictor = f0_predictor
  280. def remove_weight_norm(self):
  281. print('Removing weight norm...')
  282. for l in self.ups:
  283. remove_weight_norm(l)
  284. for l in self.resblocks:
  285. l.remove_weight_norm()
  286. remove_weight_norm(self.conv_pre)
  287. remove_weight_norm(self.conv_post)
  288. self.m_source.remove_weight_norm()
  289. for l in self.source_downs:
  290. remove_weight_norm(l)
  291. for l in self.source_resblocks:
  292. l.remove_weight_norm()
  293. def _stft(self, x):
  294. spec = torch.stft(
  295. x,
  296. self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
  297. return_complex=True)
  298. spec = torch.view_as_real(spec) # [B, F, TT, 2]
  299. return spec[..., 0], spec[..., 1]
  300. def _istft(self, magnitude, phase):
  301. magnitude = torch.clip(magnitude, max=1e2)
  302. real = magnitude * torch.cos(phase)
  303. img = magnitude * torch.sin(phase)
  304. inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
  305. self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
  306. return inverse_transform
  307. def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
  308. s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
  309. s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
  310. x = self.conv_pre(x)
  311. for i in range(self.num_upsamples):
  312. x = F.leaky_relu(x, self.lrelu_slope)
  313. x = self.ups[i](x)
  314. if i == self.num_upsamples - 1:
  315. x = self.reflection_pad(x)
  316. # fusion
  317. si = self.source_downs[i](s_stft)
  318. si = self.source_resblocks[i](si)
  319. x = x + si
  320. xs = None
  321. for j in range(self.num_kernels):
  322. if xs is None:
  323. xs = self.resblocks[i * self.num_kernels + j](x)
  324. else:
  325. xs += self.resblocks[i * self.num_kernels + j](x)
  326. x = xs / self.num_kernels
  327. x = F.leaky_relu(x)
  328. x = self.conv_post(x)
  329. magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
  330. phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
  331. x = self._istft(magnitude, phase)
  332. x = torch.clamp(x, -self.audio_limit, self.audio_limit)
  333. return x
  334. def forward(
  335. self,
  336. batch: dict,
  337. device: torch.device,
  338. ) -> Dict[str, Optional[torch.Tensor]]:
  339. speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
  340. # mel->f0
  341. f0 = self.f0_predictor(speech_feat)
  342. # f0->source
  343. s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
  344. s, _, _ = self.m_source(s)
  345. s = s.transpose(1, 2)
  346. # mel+source->speech
  347. generated_speech = self.decode(x=speech_feat, s=s)
  348. return generated_speech, f0
  349. @torch.inference_mode()
  350. def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
  351. # mel->f0
  352. f0 = self.f0_predictor(speech_feat)
  353. # f0->source
  354. s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
  355. s, _, _ = self.m_source(s)
  356. s = s.transpose(1, 2)
  357. # use cache_source to avoid glitch
  358. if cache_source.shape[2] != 0:
  359. s[:, :, :cache_source.shape[2]] = cache_source
  360. generated_speech = self.decode(x=speech_feat, s=s)
  361. return generated_speech, s