|
|
@@ -77,10 +77,10 @@ class ConditionalCFM(BASECFM):
|
|
|
sol = []
|
|
|
|
|
|
for step in range(1, len(t_span)):
|
|
|
- dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
|
|
+ dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
|
|
|
# Classifier-Free Guidance inference introduced in VoiceBox
|
|
|
if self.inference_cfg_rate > 0:
|
|
|
- cfg_dphi_dt = self.estimator(
|
|
|
+ cfg_dphi_dt = self.forward_estimator(
|
|
|
x, mask,
|
|
|
torch.zeros_like(mu), t,
|
|
|
torch.zeros_like(spks) if spks is not None else None,
|
|
|
@@ -96,6 +96,14 @@ class ConditionalCFM(BASECFM):
|
|
|
|
|
|
return sol[-1]
|
|
|
|
|
|
+ # TODO
|
|
|
+ def forward_estimator(self):
|
|
|
+ if isinstance(self.estimator, trt):
|
|
|
+ assert self.training is False, 'tensorrt cannot be used in training'
|
|
|
+ return xxx
|
|
|
+ else:
|
|
|
+ return self.estimator.forward
|
|
|
+
|
|
|
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
|
|
"""Computes diffusion loss
|
|
|
|