generator.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """HIFI-GAN"""
  15. from typing import Dict, Optional, List
  16. import numpy as np
  17. from scipy.signal import get_window
  18. import torch
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from torch.nn import Conv1d
  22. from torch.nn import ConvTranspose1d
  23. from torch.nn.utils import remove_weight_norm
  24. try:
  25. from torch.nn.utils.parametrizations import weight_norm
  26. except ImportError:
  27. from torch.nn.utils import weight_norm
  28. from torch.distributions.uniform import Uniform
  29. from cosyvoice.transformer.activation import Snake
  30. from cosyvoice.utils.common import get_padding
  31. from cosyvoice.utils.common import init_weights
  32. """hifigan based generator implementation.
This code is modified from https://github.com/jik876/hifi-gan,
https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN
  36. """
  37. class ResBlock(torch.nn.Module):
  38. """Residual block module in HiFiGAN/BigVGAN."""
  39. def __init__(
  40. self,
  41. channels: int = 512,
  42. kernel_size: int = 3,
  43. dilations: List[int] = [1, 3, 5],
  44. ):
  45. super(ResBlock, self).__init__()
  46. self.convs1 = nn.ModuleList()
  47. self.convs2 = nn.ModuleList()
  48. for dilation in dilations:
  49. self.convs1.append(
  50. weight_norm(
  51. Conv1d(
  52. channels,
  53. channels,
  54. kernel_size,
  55. 1,
  56. dilation=dilation,
  57. padding=get_padding(kernel_size, dilation)
  58. )
  59. )
  60. )
  61. self.convs2.append(
  62. weight_norm(
  63. Conv1d(
  64. channels,
  65. channels,
  66. kernel_size,
  67. 1,
  68. dilation=1,
  69. padding=get_padding(kernel_size, 1)
  70. )
  71. )
  72. )
  73. self.convs1.apply(init_weights)
  74. self.convs2.apply(init_weights)
  75. self.activations1 = nn.ModuleList([
  76. Snake(channels, alpha_logscale=False)
  77. for _ in range(len(self.convs1))
  78. ])
  79. self.activations2 = nn.ModuleList([
  80. Snake(channels, alpha_logscale=False)
  81. for _ in range(len(self.convs2))
  82. ])
  83. def forward(self, x: torch.Tensor) -> torch.Tensor:
  84. for idx in range(len(self.convs1)):
  85. xt = self.activations1[idx](x)
  86. xt = self.convs1[idx](xt)
  87. xt = self.activations2[idx](xt)
  88. xt = self.convs2[idx](xt)
  89. x = xt + x
  90. return x
  91. def remove_weight_norm(self):
  92. for idx in range(len(self.convs1)):
  93. remove_weight_norm(self.convs1[idx])
  94. remove_weight_norm(self.convs2[idx])
  95. class SineGen(torch.nn.Module):
  96. """ Definition of sine generator
  97. SineGen(samp_rate, harmonic_num = 0,
  98. sine_amp = 0.1, noise_std = 0.003,
  99. voiced_threshold = 0,
  100. flag_for_pulse=False)
  101. samp_rate: sampling rate in Hz
  102. harmonic_num: number of harmonic overtones (default 0)
  103. sine_amp: amplitude of sine-wavefrom (default 0.1)
  104. noise_std: std of Gaussian noise (default 0.003)
  105. voiced_thoreshold: F0 threshold for U/V classification (default 0)
  106. flag_for_pulse: this SinGen is used inside PulseGen (default False)
  107. Note: when flag_for_pulse is True, the first time step of a voiced
  108. segment is always sin(np.pi) or cos(0)
  109. """
  110. def __init__(self, samp_rate, harmonic_num=0,
  111. sine_amp=0.1, noise_std=0.003,
  112. voiced_threshold=0):
  113. super(SineGen, self).__init__()
  114. self.sine_amp = sine_amp
  115. self.noise_std = noise_std
  116. self.harmonic_num = harmonic_num
  117. self.sampling_rate = samp_rate
  118. self.voiced_threshold = voiced_threshold
  119. def _f02uv(self, f0):
  120. # generate uv signal
  121. uv = (f0 > self.voiced_threshold).type(torch.float32)
  122. return uv
  123. @torch.no_grad()
  124. def forward(self, f0):
  125. """
  126. :param f0: [B, 1, sample_len], Hz
  127. :return: [B, 1, sample_len]
  128. """
  129. F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
  130. for i in range(self.harmonic_num + 1):
  131. F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
  132. theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
  133. u_dist = Uniform(low=-np.pi, high=np.pi)
  134. phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
  135. phase_vec[:, 0, :] = 0
  136. # generate sine waveforms
  137. sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
  138. # generate uv signal
  139. uv = self._f02uv(f0)
  140. # noise: for unvoiced should be similar to sine_amp
  141. # std = self.sine_amp/3 -> max value ~ self.sine_amp
  142. # . for voiced regions is self.noise_std
  143. noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
  144. noise = noise_amp * torch.randn_like(sine_waves)
  145. # first: set the unvoiced part to 0 by uv
  146. # then: additive noise
  147. sine_waves = sine_waves * uv + noise
  148. return sine_waves, uv, noise
  149. class SourceModuleHnNSF(torch.nn.Module):
  150. """ SourceModule for hn-nsf
  151. SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
  152. add_noise_std=0.003, voiced_threshod=0)
  153. sampling_rate: sampling_rate in Hz
  154. harmonic_num: number of harmonic above F0 (default: 0)
  155. sine_amp: amplitude of sine source signal (default: 0.1)
  156. add_noise_std: std of additive Gaussian noise (default: 0.003)
  157. note that amplitude of noise in unvoiced is decided
  158. by sine_amp
  159. voiced_threshold: threhold to set U/V given F0 (default: 0)
  160. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  161. F0_sampled (batchsize, length, 1)
  162. Sine_source (batchsize, length, 1)
  163. noise_source (batchsize, length 1)
  164. uv (batchsize, length, 1)
  165. """
  166. def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
  167. add_noise_std=0.003, voiced_threshod=0):
  168. super(SourceModuleHnNSF, self).__init__()
  169. self.sine_amp = sine_amp
  170. self.noise_std = add_noise_std
  171. # to produce sine waveforms
  172. self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
  173. sine_amp, add_noise_std, voiced_threshod)
  174. # to merge source harmonics into a single excitation
  175. self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
  176. self.l_tanh = torch.nn.Tanh()
  177. def forward(self, x):
  178. """
  179. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  180. F0_sampled (batchsize, length, 1)
  181. Sine_source (batchsize, length, 1)
  182. noise_source (batchsize, length 1)
  183. """
  184. # source for harmonic branch
  185. with torch.no_grad():
  186. sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
  187. sine_wavs = sine_wavs.transpose(1, 2)
  188. uv = uv.transpose(1, 2)
  189. sine_merge = self.l_tanh(self.l_linear(sine_wavs))
  190. # source for noise branch, in the same shape as uv
  191. noise = torch.randn_like(uv) * self.sine_amp / 3
  192. return sine_merge, noise, uv
  193. class SineGen2(torch.nn.Module):
  194. """ Definition of sine generator
  195. SineGen(samp_rate, harmonic_num = 0,
  196. sine_amp = 0.1, noise_std = 0.003,
  197. voiced_threshold = 0,
  198. flag_for_pulse=False)
  199. samp_rate: sampling rate in Hz
  200. harmonic_num: number of harmonic overtones (default 0)
  201. sine_amp: amplitude of sine-wavefrom (default 0.1)
  202. noise_std: std of Gaussian noise (default 0.003)
  203. voiced_thoreshold: F0 threshold for U/V classification (default 0)
  204. flag_for_pulse: this SinGen is used inside PulseGen (default False)
  205. Note: when flag_for_pulse is True, the first time step of a voiced
  206. segment is always sin(np.pi) or cos(0)
  207. """
  208. def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
  209. sine_amp=0.1, noise_std=0.003,
  210. voiced_threshold=0,
  211. flag_for_pulse=False):
  212. super(SineGen2, self).__init__()
  213. self.sine_amp = sine_amp
  214. self.noise_std = noise_std
  215. self.harmonic_num = harmonic_num
  216. self.dim = self.harmonic_num + 1
  217. self.sampling_rate = samp_rate
  218. self.voiced_threshold = voiced_threshold
  219. self.flag_for_pulse = flag_for_pulse
  220. self.upsample_scale = upsample_scale
  221. def _f02uv(self, f0):
  222. # generate uv signal
  223. uv = (f0 > self.voiced_threshold).type(torch.float32)
  224. return uv
  225. def _f02sine(self, f0_values):
  226. """ f0_values: (batchsize, length, dim)
  227. where dim indicates fundamental tone and overtones
  228. """
  229. # convert to F0 in rad. The interger part n can be ignored
  230. # because 2 * np.pi * n doesn't affect phase
  231. rad_values = (f0_values / self.sampling_rate) % 1
  232. # initial phase noise (no noise for fundamental component)
  233. rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
  234. rand_ini[:, 0] = 0
  235. rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
  236. # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
  237. if not self.flag_for_pulse:
  238. rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
  239. scale_factor=1 / self.upsample_scale,
  240. mode="linear").transpose(1, 2)
  241. phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
  242. phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
  243. scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
  244. sines = torch.sin(phase)
  245. else:
  246. # If necessary, make sure that the first time step of every
  247. # voiced segments is sin(pi) or cos(0)
  248. # This is used for pulse-train generation
  249. # identify the last time step in unvoiced segments
  250. uv = self._f02uv(f0_values)
  251. uv_1 = torch.roll(uv, shifts=-1, dims=1)
  252. uv_1[:, -1, :] = 1
  253. u_loc = (uv < 1) * (uv_1 > 0)
  254. # get the instantanouse phase
  255. tmp_cumsum = torch.cumsum(rad_values, dim=1)
  256. # different batch needs to be processed differently
  257. for idx in range(f0_values.shape[0]):
  258. temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
  259. temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
  260. # stores the accumulation of i.phase within
  261. # each voiced segments
  262. tmp_cumsum[idx, :, :] = 0
  263. tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
  264. # rad_values - tmp_cumsum: remove the accumulation of i.phase
  265. # within the previous voiced segment.
  266. i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
  267. # get the sines
  268. sines = torch.cos(i_phase * 2 * np.pi)
  269. return sines
  270. def forward(self, f0):
  271. """ sine_tensor, uv = forward(f0)
  272. input F0: tensor(batchsize=1, length, dim=1)
  273. f0 for unvoiced steps should be 0
  274. output sine_tensor: tensor(batchsize=1, length, dim)
  275. output uv: tensor(batchsize=1, length, 1)
  276. """
  277. # fundamental component
  278. fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
  279. # generate sine waveforms
  280. sine_waves = self._f02sine(fn) * self.sine_amp
  281. # generate uv signal
  282. uv = self._f02uv(f0)
  283. # noise: for unvoiced should be similar to sine_amp
  284. # std = self.sine_amp/3 -> max value ~ self.sine_amp
  285. # . for voiced regions is self.noise_std
  286. noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
  287. noise = noise_amp * torch.randn_like(sine_waves)
  288. # first: set the unvoiced part to 0 by uv
  289. # then: additive noise
  290. sine_waves = sine_waves * uv + noise
  291. return sine_waves, uv, noise
  292. class SourceModuleHnNSF2(torch.nn.Module):
  293. """ SourceModule for hn-nsf
  294. SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
  295. add_noise_std=0.003, voiced_threshod=0)
  296. sampling_rate: sampling_rate in Hz
  297. harmonic_num: number of harmonic above F0 (default: 0)
  298. sine_amp: amplitude of sine source signal (default: 0.1)
  299. add_noise_std: std of additive Gaussian noise (default: 0.003)
  300. note that amplitude of noise in unvoiced is decided
  301. by sine_amp
  302. voiced_threshold: threhold to set U/V given F0 (default: 0)
  303. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  304. F0_sampled (batchsize, length, 1)
  305. Sine_source (batchsize, length, 1)
  306. noise_source (batchsize, length 1)
  307. uv (batchsize, length, 1)
  308. """
  309. def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
  310. add_noise_std=0.003, voiced_threshod=0):
  311. super(SourceModuleHnNSF2, self).__init__()
  312. self.sine_amp = sine_amp
  313. self.noise_std = add_noise_std
  314. # to produce sine waveforms
  315. self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
  316. sine_amp, add_noise_std, voiced_threshod)
  317. # to merge source harmonics into a single excitation
  318. self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
  319. self.l_tanh = torch.nn.Tanh()
  320. def forward(self, x):
  321. """
  322. Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
  323. F0_sampled (batchsize, length, 1)
  324. Sine_source (batchsize, length, 1)
  325. noise_source (batchsize, length 1)
  326. """
  327. # source for harmonic branch
  328. with torch.no_grad():
  329. sine_wavs, uv, _ = self.l_sin_gen(x)
  330. sine_merge = self.l_tanh(self.l_linear(sine_wavs))
  331. # source for noise branch, in the same shape as uv
  332. noise = torch.randn_like(uv) * self.sine_amp / 3
  333. return sine_merge, noise, uv
  334. class HiFTGenerator(nn.Module):
  335. """
  336. HiFTNet Generator: Neural Source Filter + ISTFTNet
  337. https://arxiv.org/abs/2309.09493
  338. """
  339. def __init__(
  340. self,
  341. in_channels: int = 80,
  342. base_channels: int = 512,
  343. nb_harmonics: int = 8,
  344. sampling_rate: int = 22050,
  345. nsf_alpha: float = 0.1,
  346. nsf_sigma: float = 0.003,
  347. nsf_voiced_threshold: float = 10,
  348. upsample_rates: List[int] = [8, 8],
  349. upsample_kernel_sizes: List[int] = [16, 16],
  350. istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
  351. resblock_kernel_sizes: List[int] = [3, 7, 11],
  352. resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
  353. source_resblock_kernel_sizes: List[int] = [7, 11],
  354. source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
  355. lrelu_slope: float = 0.1,
  356. audio_limit: float = 0.99,
  357. f0_predictor: torch.nn.Module = None,
  358. ):
  359. super(HiFTGenerator, self).__init__()
  360. self.out_channels = 1
  361. self.nb_harmonics = nb_harmonics
  362. self.sampling_rate = sampling_rate
  363. self.istft_params = istft_params
  364. self.lrelu_slope = lrelu_slope
  365. self.audio_limit = audio_limit
  366. self.num_kernels = len(resblock_kernel_sizes)
  367. self.num_upsamples = len(upsample_rates)
  368. # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
  369. this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
  370. self.m_source = this_SourceModuleHnNSF(
  371. sampling_rate=sampling_rate,
  372. upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
  373. harmonic_num=nb_harmonics,
  374. sine_amp=nsf_alpha,
  375. add_noise_std=nsf_sigma,
  376. voiced_threshod=nsf_voiced_threshold)
  377. self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
  378. self.conv_pre = weight_norm(
  379. Conv1d(in_channels, base_channels, 7, 1, padding=3)
  380. )
  381. # Up
  382. self.ups = nn.ModuleList()
  383. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  384. self.ups.append(
  385. weight_norm(
  386. ConvTranspose1d(
  387. base_channels // (2**i),
  388. base_channels // (2**(i + 1)),
  389. k,
  390. u,
  391. padding=(k - u) // 2,
  392. )
  393. )
  394. )
  395. # Down
  396. self.source_downs = nn.ModuleList()
  397. self.source_resblocks = nn.ModuleList()
  398. downsample_rates = [1] + upsample_rates[::-1][:-1]
  399. downsample_cum_rates = np.cumprod(downsample_rates)
  400. for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
  401. if u == 1:
  402. self.source_downs.append(
  403. Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
  404. )
  405. else:
  406. self.source_downs.append(
  407. Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
  408. )
  409. self.source_resblocks.append(
  410. ResBlock(base_channels // (2 ** (i + 1)), k, d)
  411. )
  412. self.resblocks = nn.ModuleList()
  413. for i in range(len(self.ups)):
  414. ch = base_channels // (2**(i + 1))
  415. for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
  416. self.resblocks.append(ResBlock(ch, k, d))
  417. self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
  418. self.ups.apply(init_weights)
  419. self.conv_post.apply(init_weights)
  420. self.reflection_pad = nn.ReflectionPad1d((1, 0))
  421. self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
  422. self.f0_predictor = f0_predictor
  423. def remove_weight_norm(self):
  424. print('Removing weight norm...')
  425. for l in self.ups:
  426. remove_weight_norm(l)
  427. for l in self.resblocks:
  428. l.remove_weight_norm()
  429. remove_weight_norm(self.conv_pre)
  430. remove_weight_norm(self.conv_post)
  431. self.m_source.remove_weight_norm()
  432. for l in self.source_downs:
  433. remove_weight_norm(l)
  434. for l in self.source_resblocks:
  435. l.remove_weight_norm()
  436. def _stft(self, x):
  437. spec = torch.stft(
  438. x,
  439. self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
  440. return_complex=True)
  441. spec = torch.view_as_real(spec) # [B, F, TT, 2]
  442. return spec[..., 0], spec[..., 1]
  443. def _istft(self, magnitude, phase):
  444. magnitude = torch.clip(magnitude, max=1e2)
  445. real = magnitude * torch.cos(phase)
  446. img = magnitude * torch.sin(phase)
  447. inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
  448. self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
  449. return inverse_transform
  450. def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
  451. s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
  452. s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
  453. x = self.conv_pre(x)
  454. for i in range(self.num_upsamples):
  455. x = F.leaky_relu(x, self.lrelu_slope)
  456. x = self.ups[i](x)
  457. if i == self.num_upsamples - 1:
  458. x = self.reflection_pad(x)
  459. # fusion
  460. si = self.source_downs[i](s_stft)
  461. si = self.source_resblocks[i](si)
  462. x = x + si
  463. xs = None
  464. for j in range(self.num_kernels):
  465. if xs is None:
  466. xs = self.resblocks[i * self.num_kernels + j](x)
  467. else:
  468. xs += self.resblocks[i * self.num_kernels + j](x)
  469. x = xs / self.num_kernels
  470. x = F.leaky_relu(x)
  471. x = self.conv_post(x)
  472. magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
  473. phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
  474. x = self._istft(magnitude, phase)
  475. x = torch.clamp(x, -self.audio_limit, self.audio_limit)
  476. return x
  477. def forward(
  478. self,
  479. batch: dict,
  480. device: torch.device,
  481. ) -> Dict[str, Optional[torch.Tensor]]:
  482. speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
  483. # mel->f0
  484. f0 = self.f0_predictor(speech_feat)
  485. # f0->source
  486. s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
  487. s, _, _ = self.m_source(s)
  488. s = s.transpose(1, 2)
  489. # mel+source->speech
  490. generated_speech = self.decode(x=speech_feat, s=s)
  491. return generated_speech, f0
  492. @torch.inference_mode()
  493. def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
  494. # mel->f0
  495. f0 = self.f0_predictor(speech_feat)
  496. # f0->source
  497. s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
  498. s, _, _ = self.m_source(s)
  499. s = s.transpose(1, 2)
  500. # use cache_source to avoid glitch
  501. if cache_source.shape[2] != 0:
  502. s[:, :, :cache_source.shape[2]] = cache_source
  503. generated_speech = self.decode(x=speech_feat, s=s)
  504. return generated_speech, s