|
|
@@ -14,7 +14,7 @@
|
|
|
|
|
|
"""HIFI-GAN"""
|
|
|
|
|
|
-import typing as tp
|
|
|
+from typing import Dict, Optional, List
|
|
|
import numpy as np
|
|
|
from scipy.signal import get_window
|
|
|
import torch
|
|
|
@@ -46,7 +46,7 @@ class ResBlock(torch.nn.Module):
|
|
|
self,
|
|
|
channels: int = 512,
|
|
|
kernel_size: int = 3,
|
|
|
- dilations: tp.List[int] = [1, 3, 5],
|
|
|
+ dilations: List[int] = [1, 3, 5],
|
|
|
):
|
|
|
super(ResBlock, self).__init__()
|
|
|
self.convs1 = nn.ModuleList()
|
|
|
@@ -234,13 +234,13 @@ class HiFTGenerator(nn.Module):
|
|
|
nsf_alpha: float = 0.1,
|
|
|
nsf_sigma: float = 0.003,
|
|
|
nsf_voiced_threshold: float = 10,
|
|
|
- upsample_rates: tp.List[int] = [8, 8],
|
|
|
- upsample_kernel_sizes: tp.List[int] = [16, 16],
|
|
|
- istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
|
|
- resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
|
|
|
- resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
|
|
- source_resblock_kernel_sizes: tp.List[int] = [7, 11],
|
|
|
- source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
|
|
|
+ upsample_rates: List[int] = [8, 8],
|
|
|
+ upsample_kernel_sizes: List[int] = [16, 16],
|
|
|
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
|
|
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
|
|
|
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
|
|
+ source_resblock_kernel_sizes: List[int] = [7, 11],
|
|
|
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
|
|
lrelu_slope: float = 0.1,
|
|
|
audio_limit: float = 0.99,
|
|
|
f0_predictor: torch.nn.Module = None,
|
|
|
@@ -316,11 +316,19 @@ class HiFTGenerator(nn.Module):
|
|
|
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
|
|
self.f0_predictor = f0_predictor
|
|
|
|
|
|
- def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
|
|
|
- f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
|
|
-
|
|
|
- har_source, _, _ = self.m_source(f0)
|
|
|
- return har_source.transpose(1, 2)
|
|
|
+ def remove_weight_norm(self):
|
|
|
+ print('Removing weight norm...')
|
|
|
+ for l in self.ups:
|
|
|
+ remove_weight_norm(l)
|
|
|
+ for l in self.resblocks:
|
|
|
+ l.remove_weight_norm()
|
|
|
+ remove_weight_norm(self.conv_pre)
|
|
|
+ remove_weight_norm(self.conv_post)
|
|
|
+ self.m_source.remove_weight_norm()
|
|
|
+ for l in self.source_downs:
|
|
|
+ remove_weight_norm(l)
|
|
|
+ for l in self.source_resblocks:
|
|
|
+ l.remove_weight_norm()
|
|
|
|
|
|
def _stft(self, x):
|
|
|
spec = torch.stft(
|
|
|
@@ -338,14 +346,7 @@ class HiFTGenerator(nn.Module):
|
|
|
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
|
|
return inverse_transform
|
|
|
|
|
|
- def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
|
|
- f0 = self.f0_predictor(x)
|
|
|
- s = self._f02source(f0)
|
|
|
-
|
|
|
- # use cache_source to avoid glitch
|
|
|
- if cache_source.shape[2] != 0:
|
|
|
- s[:, :, :cache_source.shape[2]] = cache_source
|
|
|
-
|
|
|
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
|
|
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
|
|
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
|
|
|
|
|
@@ -377,22 +378,34 @@ class HiFTGenerator(nn.Module):
|
|
|
|
|
|
x = self._istft(magnitude, phase)
|
|
|
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
|
|
- return x, s
|
|
|
+ return x
|
|
|
|
|
|
- def remove_weight_norm(self):
|
|
|
- print('Removing weight norm...')
|
|
|
- for l in self.ups:
|
|
|
- remove_weight_norm(l)
|
|
|
- for l in self.resblocks:
|
|
|
- l.remove_weight_norm()
|
|
|
- remove_weight_norm(self.conv_pre)
|
|
|
- remove_weight_norm(self.conv_post)
|
|
|
- self.source_module.remove_weight_norm()
|
|
|
- for l in self.source_downs:
|
|
|
- remove_weight_norm(l)
|
|
|
- for l in self.source_resblocks:
|
|
|
- l.remove_weight_norm()
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ batch: dict,
|
|
|
+ device: torch.device,
|
|
|
+ ) -> Dict[str, Optional[torch.Tensor]]:
|
|
|
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
|
|
|
+ # mel->f0
|
|
|
+ f0 = self.f0_predictor(speech_feat)
|
|
|
+ # f0->source
|
|
|
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
|
|
+ s, _, _ = self.m_source(s)
|
|
|
+ s = s.transpose(1, 2)
|
|
|
+ # mel+source->speech
|
|
|
+ generated_speech = self.decode(x=speech_feat, s=s)
|
|
|
+ return generated_speech, f0
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
- def inference(self, mel: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
|
|
- return self.forward(x=mel, cache_source=cache_source)
|
|
|
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
|
|
+ # mel->f0
|
|
|
+ f0 = self.f0_predictor(speech_feat)
|
|
|
+ # f0->source
|
|
|
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
|
|
+ s, _, _ = self.m_source(s)
|
|
|
+ s = s.transpose(1, 2)
|
|
|
+ # use cache_source to avoid glitch
|
|
|
+ if cache_source.shape[2] != 0:
|
|
|
+ s[:, :, :cache_source.shape[2]] = cache_source
|
|
|
+ generated_speech = self.decode(x=speech_feat, s=s)
|
|
|
+ return generated_speech, s
|