|
|
@@ -126,6 +126,13 @@ class ConditionalCFM(BASECFM):
|
|
|
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
|
|
u = x1 - (1 - self.sigma_min) * z
|
|
|
|
|
|
+ # During training, randomly zero out the conditioning inputs (mu, spks, cond) to trade off mode coverage against sample fidelity.
|
|
|
+ if self.training_cfg_rate > 0:
|
|
|
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
|
|
+ mu = mu * cfg_mask.view(-1, 1, 1)
|
|
|
+ spks = spks * cfg_mask.view(-1, 1)
|
|
|
+ cond = cond * cfg_mask.view(-1, 1, 1)
|
|
|
+
|
|
|
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
|
|
|
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
|
|
return loss, y
|