| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391 |
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """HIFI-GAN"""
- import typing as tp
- import numpy as np
- from scipy.signal import get_window
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn import Conv1d
- from torch.nn import ConvTranspose1d
- from torch.nn.utils import remove_weight_norm
- from torch.nn.utils import weight_norm
- from torch.distributions.uniform import Uniform
- from cosyvoice.transformer.activation import Snake
- from academicodec.utils import get_padding
- from academicodec.utils import init_weights
- """hifigan based generator implementation.
- This code is modified from https://github.com/jik876/hifi-gan
- ,https://github.com/kan-bayashi/ParallelWaveGAN and
- https://github.com/NVIDIA/BigVGAN
- """
class ResBlock(torch.nn.Module):
    """Stack of residual units used by the HiFiGAN/BigVGAN style generator.

    One unit is built per entry of ``dilations``. Each unit runs
    activation -> dilated conv -> activation -> conv (dilation 1) and adds
    the result back onto the unit's input.

    Args:
        channels: number of input/output channels of every convolution.
        kernel_size: kernel size shared by all convolutions.
        dilations: dilation of the first convolution of each unit.
    """

    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: tp.List[int] = [1, 3, 5],
    ):
        super().__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        # The two convolutions of a unit are created back to back so that the
        # order in which parameters are initialised stays the same.
        for dilation in dilations:
            dilated_conv = Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=dilation,
                padding=get_padding(kernel_size, dilation),
            )
            self.convs1.append(weight_norm(dilated_conv))

            plain_conv = Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=1,
                padding=get_padding(kernel_size, 1),
            )
            self.convs2.append(weight_norm(plain_conv))

        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)

        # One activation in front of every convolution.
        num_first = len(self.convs1)
        num_second = len(self.convs2)
        self.activations1 = nn.ModuleList(
            [Snake(channels, alpha_logscale=False) for _ in range(num_first)]
        )
        self.activations2 = nn.ModuleList(
            [Snake(channels, alpha_logscale=False) for _ in range(num_second)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every residual unit in order and return the result."""
        units = zip(self.activations1, self.convs1, self.activations2, self.convs2)
        for act_in, conv_in, act_out, conv_out in units:
            branch = conv_in(act_in(x))
            branch = conv_out(act_out(branch))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrisation from every convolution."""
        for conv_in, conv_out in zip(self.convs1, self.convs2):
            remove_weight_norm(conv_in)
            remove_weight_norm(conv_out)
class SineGen(torch.nn.Module):
    """Sine-wave generator driven by an F0 contour.

    SineGen(samp_rate, harmonic_num=0,
            sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of the sine waveform (default 0.1)
    noise_std: std of the Gaussian noise in voiced regions (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super().__init__()
        self.sampling_rate = samp_rate
        self.harmonic_num = harmonic_num
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        """Return a float32 mask that is 1 where ``f0`` exceeds the voiced threshold."""
        is_voiced = f0 > self.voiced_threshold
        return is_voiced.type(torch.float32)

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: tuple ``(sine_waves, uv, noise)`` where ``sine_waves`` and
            ``noise`` are [B, harmonic_num + 1, sample_len] and ``uv`` has the
            shape of ``f0``
        """
        batch_size = f0.size(0)
        num_samples = f0.size(-1)
        num_tones = self.harmonic_num + 1

        # Per-sample frequency of the fundamental and each overtone, expressed
        # as a fraction of the sampling rate.
        norm_freq = torch.zeros((batch_size, num_tones, num_samples)).to(f0.device)
        for order in range(1, num_tones + 1):
            norm_freq[:, order - 1: order, :] = f0 * order / self.sampling_rate

        # Integrate frequency into phase, wrapped to a single period.
        wrapped = torch.cumsum(norm_freq, dim=-1) % 1
        theta_mat = 2 * np.pi * wrapped

        # Random initial phase for every overtone; the fundamental starts at 0.
        phase_sampler = Uniform(low=-np.pi, high=np.pi)
        phase_vec = phase_sampler.sample(sample_shape=(batch_size, num_tones, 1))
        phase_vec = phase_vec.to(norm_freq.device)
        phase_vec[:, 0, :] = 0

        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        uv = self._f02uv(f0)

        # Noise std is noise_std in voiced regions and sine_amp / 3 in unvoiced
        # regions, so unvoiced noise peaks at roughly sine_amp.
        voiced_part = uv * self.noise_std
        unvoiced_part = (1 - uv) * self.sine_amp / 3
        noise = (voiced_part + unvoiced_part) * torch.randn_like(sine_waves)

        # Silence the sines where unvoiced, then add the noise everywhere.
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
    """Source module for hn-nsf.

    SourceModule(sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)

    sampling_rate: sampling rate in Hz
    upsample_scale: accepted by the constructor but not used by this class
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of the sine source signal (default: 0.1)
    add_noise_std: std of the additive Gaussian noise (default: 0.003);
        note that the amplitude of the noise in unvoiced regions is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)

    Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super().__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # Produces the fundamental plus `harmonic_num` overtones.
        self.l_sin_gen = SineGen(
            sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
        )

        # Collapses all harmonics into one excitation channel.
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source, uv = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        uv (batchsize, length, 1)
        """
        # Harmonic branch: the sine generator works channel-first, so swap the
        # last two axes on the way in and on the way out. No gradients flow
        # through the generator itself.
        with torch.no_grad():
            f0_channel_first = x.transpose(1, 2)
            harmonics, voiced_mask, _ = self.l_sin_gen(f0_channel_first)
            harmonics = harmonics.transpose(1, 2)
            voiced_mask = voiced_mask.transpose(1, 2)

        merged = self.l_linear(harmonics)
        sine_merge = self.l_tanh(merged)

        # Noise branch, same shape as the voiced/unvoiced mask.
        noise = torch.randn_like(voiced_mask) * self.sine_amp / 3
        return sine_merge, noise, voiced_mask
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    ``forward`` predicts F0 from the input with ``f0_predictor``, turns it
    into a harmonic source signal, and fuses the STFT of that source into an
    upsampling convolutional stack. The last convolution emits ``n_fft + 2``
    channels which are split into magnitude and phase and converted into a
    waveform with an inverse STFT.

    Args:
        in_channels: channel count of the input fed to the first convolution.
        base_channels: channel count after the first convolution; it is
            halved at every upsampling stage.
        nb_harmonics: number of harmonics passed to the source module.
        sampling_rate: sampling rate in Hz passed to the source module.
        nsf_alpha: sine amplitude of the source module.
        nsf_sigma: additive noise std of the source module.
        nsf_voiced_threshold: F0 threshold for voiced/unvoiced decisions.
        upsample_rates: stride of each transposed convolution.
        upsample_kernel_sizes: kernel size of each transposed convolution.
        istft_params: mapping with the keys ``"n_fft"`` and ``"hop_len"``.
        resblock_kernel_sizes: kernel sizes of the ResBlocks of every stage.
        resblock_dilation_sizes: dilations of the ResBlocks of every stage.
        source_resblock_kernel_sizes: kernel size of the source-branch
            ResBlock of each stage.
        source_resblock_dilation_sizes: dilations of the source-branch
            ResBlock of each stage.
        lrelu_slope: negative slope of the leaky ReLU applied before every
            upsampling layer.
        audio_limit: the output waveform is clamped to ``[-limit, limit]``.
        f0_predictor: module called on the input in ``forward`` to get F0.
            ``forward`` cannot run while this is ``None``.
    """

    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: tp.List[int] = [8, 8],
        upsample_kernel_sizes: tp.List[int] = [16, 16],
        istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
        resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: tp.List[int] = [7, 11],
        source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: tp.Optional[torch.nn.Module] = None,
    ):
        super(HiFTGenerator, self).__init__()

        # Work on private copies: the defaults are shared mutable objects, and
        # the list concatenation further down needs a real list.
        upsample_rates = list(upsample_rates)
        istft_params = dict(istft_params)

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # Total factor between the F0 frame rate and the waveform sample rate.
        total_upsample = np.prod(upsample_rates) * istft_params["hop_len"]
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=total_upsample,
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=total_upsample)

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down: strided convolutions that bring the source STFT to the frame
        # rate of each upsampling stage. These are plain convolutions, i.e.
        # they are NOT wrapped in weight norm.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
                                          source_resblock_dilation_sizes)):
            u = int(u)  # plain int instead of a NumPy scalar for the conv arguments
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        # `num_kernels` ResBlocks per upsampling stage, stored flat.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock(ch, k, d))

        # `ch` is the channel count of the last upsampling stage.
        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        # Plain tensor attribute (not a registered buffer); it is moved to the
        # input's device on every STFT/ISTFT call.
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
        """Upsample ``f0`` to sample rate and return the harmonic source, channel-first."""
        # assumes f0 is (batch, frames) -- TODO confirm against the f0 predictor
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

        har_source, _, _ = self.m_source(f0)
        return har_source.transpose(1, 2)

    def _stft(self, x):
        """Return the real and imaginary parts of the STFT of ``x`` as two tensors."""
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"],
            window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Rebuild a waveform from ``magnitude`` and ``phase`` with an inverse STFT."""
        # Cap the magnitude before combining it with the phase.
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(
            torch.complex(real, img),
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"],
            window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate a waveform from ``x`` (channel-first features, ``in_channels`` channels).

        Returns the inverse-STFT output clamped to ``[-audio_limit, audio_limit]``.
        """
        f0 = self.f0_predictor(x)
        s = self._f02source(f0)

        # Source STFT with real and imaginary parts stacked along channels
        # (n_fft + 2 channels in total).
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                # One extra frame on the left; presumably so the length matches
                # the frame count of the source STFT -- TODO confirm.
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # Average the outputs of this stage's ResBlocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # NOTE(review): this call uses F.leaky_relu's default slope, not
        # self.lrelu_slope. Left untouched because changing it would alter the
        # output of already trained weights.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def remove_weight_norm(self):
        """Strip weight norm from every layer of this generator that carries it."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        # The source module (self.m_source) and the self.source_downs
        # convolutions are built without weight norm, so there is nothing to
        # strip from them; trying to do so raises.
        for l in self.source_resblocks:
            l.remove_weight_norm()

    @torch.inference_mode()
    def inference(self, mel: torch.Tensor) -> torch.Tensor:
        """Run ``forward`` on ``mel`` under ``torch.inference_mode``."""
        return self.forward(x=mel)
|