|
@@ -30,6 +30,9 @@ class ConditionalCFM(BASECFM):
|
|
|
# Just change the architecture of the estimator here
|
|
# Just change the architecture of the estimator here
|
|
|
self.estimator = estimator
|
|
self.estimator = estimator
|
|
|
|
|
|
|
|
|
|
+ self.estimator_context = None
|
|
|
|
|
+ self.estimator_engine = None
|
|
|
|
|
+
|
|
|
@torch.inference_mode()
|
|
@torch.inference_mode()
|
|
|
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
|
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
|
|
"""Forward diffusion
|
|
"""Forward diffusion
|
|
@@ -50,7 +53,7 @@ class ConditionalCFM(BASECFM):
|
|
|
shape: (batch_size, n_feats, mel_timesteps)
|
|
shape: (batch_size, n_feats, mel_timesteps)
|
|
|
"""
|
|
"""
|
|
|
z = torch.randn_like(mu) * temperature
|
|
z = torch.randn_like(mu) * temperature
|
|
|
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
|
|
|
|
|
|
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
|
|
if self.t_scheduler == 'cosine':
|
|
if self.t_scheduler == 'cosine':
|
|
|
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
|
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
|
|
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
|
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
|
@@ -71,6 +74,7 @@ class ConditionalCFM(BASECFM):
|
|
|
cond: Not used but kept for future purposes
|
|
cond: Not used but kept for future purposes
|
|
|
"""
|
|
"""
|
|
|
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
|
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
|
|
|
|
+ t = t.unsqueeze(dim=0)
|
|
|
|
|
|
|
|
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
|
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
|
|
# Or in future might add like a return_all_steps flag
|
|
# Or in future might add like a return_all_steps flag
|
|
@@ -96,13 +100,30 @@ class ConditionalCFM(BASECFM):
|
|
|
|
|
|
|
|
return sol[-1]
|
|
return sol[-1]
|
|
|
|
|
|
|
|
def forward_estimator(self, x, mask, mu, t, spks, cond):
    """Evaluate the estimator once, via TensorRT when a context is attached.

    When ``self.estimator_context`` is ``None`` the call is forwarded to
    ``self.estimator.forward``. Otherwise the inputs are bound to the
    TensorRT execution context and inference is enqueued on the current
    CUDA stream.

    Args:
        x, mask, mu, t, spks, cond: Estimator inputs, passed through in this
            order. On the TensorRT path every one of them must be a tensor,
            because its ``shape`` and ``data_ptr()`` are used.
            NOTE(review): the TensorRT path presumably needs CUDA tensors
            whose dtypes match the engine -- confirm against the exporter.

    Returns:
        The estimator output. On the TensorRT path this is a newly allocated
        tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: If the TensorRT path is taken while ``self.training``
            is not ``False``.
        RuntimeError: If TensorRT rejects an input shape or a tensor address,
            or fails to enqueue the inference.
    """
    if self.estimator_context is None:
        return self.estimator.forward(x, mask, mu, t, spks, cond)

    assert self.training is False, 'tensorrt cannot be used in training'

    context = self.estimator_context
    input_names = ("x", "mask", "mu", "t", "spks", "cond")
    # TensorRT only sees raw pointers and assumes a dense layout, so a
    # strided view would be read incorrectly. ``contiguous()`` is a no-op for
    # tensors that are already dense.
    inputs = [tensor.contiguous() for tensor in (x, mask, mu, t, spks, cond)]

    for name, tensor in zip(input_names, inputs):
        if not context.set_input_shape(name, tensor.shape):
            raise RuntimeError(
                f"tensorrt rejected shape {tuple(tensor.shape)} for input '{name}'"
            )

    # Allocate the output from the dense copy of x so it is dense as well.
    ret = torch.empty_like(inputs[0])

    # Engine I/O tensor i is bound to the i-th buffer: the six inputs in the
    # order above, followed by the output.
    for index, tensor in enumerate(inputs + [ret]):
        tensor_name = self.estimator_engine.get_tensor_name(index)
        if not context.set_tensor_address(tensor_name, tensor.data_ptr()):
            raise RuntimeError(
                f"tensorrt rejected the address for tensor '{tensor_name}'"
            )

    # Enqueue on the stream PyTorch is currently using so later torch ops on
    # ``ret`` are ordered after the inference.
    stream_handle = torch.cuda.current_stream().cuda_stream
    if not context.execute_async_v3(stream_handle=stream_handle):
        raise RuntimeError("tensorrt failed to enqueue estimator inference")
    return ret
|
|
|
|
|
|
|
|
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
|
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
|
|
"""Computes diffusion loss
|
|
"""Computes diffusion loss
|