@@ -32,7 +32,7 @@ class ConditionalCFM(BASECFM):
         self.estimator = estimator
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
         """Forward diffusion
 
         Args:
@@ -50,11 +50,26 @@ class ConditionalCFM(BASECFM):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
-        z = torch.randn_like(mu) * temperature
+
+        if flow_cache is not None:
+            z_cache = flow_cache[0]
+            mu_cache = flow_cache[1]
+            z = torch.randn((mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)), dtype=mu.dtype, device=mu.device) * temperature
+            z = torch.cat((z_cache, z), dim=2)  # [B, 80, T]
+            mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2)  # [B, 80, T]
+        else:
+            z = torch.randn_like(mu) * temperature
+
+        next_cache_start = max(z.size(2) - required_cache_size, 0)
+        flow_cache = [
+            z[..., next_cache_start:],
+            mu[..., next_cache_start:]
+        ]
+
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
 
     def solve_euler(self, x, t_span, mu, mask, spks, cond):
         """
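
Usage note: a minimal sketch of how the extended signature could be driven across streaming chunks, assuming a ConditionalCFM instance named `cfm` and an iterable `mu_chunks` of encoder outputs whose later chunks are prefixed with the frames to be re-used; `cfm`, `mu_chunks`, and the chosen `required_cache_size` are illustrative stand-ins, not part of the patch.

    import torch

    # Hypothetical streaming driver for the cached forward(); names are illustrative.
    flow_cache = None          # first chunk runs without a cache
    required_cache_size = 34   # assumed number of mel frames kept for the next chunk

    for mu in mu_chunks:       # each mu: [B, 80, T]; later chunks include the cached prefix frames
        mask = torch.ones(mu.size(0), 1, mu.size(2), dtype=mu.dtype, device=mu.device)
        # forward() now returns (mel, flow_cache); the cache is fed back in on the next call
        mel, flow_cache = cfm(
            mu, mask, n_timesteps=10,
            required_cache_size=required_cache_size,
            flow_cache=flow_cache,
        )

Because the cached noise `z` and conditioning `mu` for the overlap region are identical across calls, the overlapping frames are re-generated from the same inputs, which keeps chunk boundaries consistent.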