@@ -30,10 +30,6 @@ class ConditionalCFM(BASECFM):
        # Just change the architecture of the estimator here
        self.estimator = estimator
-        self.estimator_context = None
-        self.estimator_engine = None
-        self.is_saved = None
-

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion
@@ -102,7 +98,11 @@ class ConditionalCFM(BASECFM):
        return sol[-1]

    def forward_estimator(self, x, mask, mu, t, spks, cond):
-        if self.estimator_context is not None:
+
+        if isinstance(self.estimator, torch.nn.Module):
+            return self.estimator.forward(x, mask, mu, t, spks, cond)
+
+        else:
            assert self.training is False, 'tensorrt cannot be used in training'
            bs = x.shape[0]
            hs = x.shape[1]
@@ -116,50 +116,14 @@ class ConditionalCFM(BASECFM):
            self.estimator_context.set_input_shape("spks", spks.shape)
            self.estimator_context.set_input_shape("cond", cond.shape)
            bindings = [x.data_ptr(), mask.data_ptr(), mu.data_ptr(), t.data_ptr(), spks.data_ptr(), cond.data_ptr(), ret.data_ptr()]
+            names = ['x', 'mask', 'mu', 't', 'spks', 'cond', 'estimator_out']

            for i in range(len(bindings)):
-                self.estimator_context.set_tensor_address(self.estimator_engine.get_tensor_name(i), bindings[i])
+                self.estimator.set_tensor_address(names[i], bindings[i])

            handle = torch.cuda.current_stream().cuda_stream
-            self.estimator_context.execute_async_v3(stream_handle=handle)
+            self.estimator.execute_async_v3(stream_handle=handle)
            return ret
-        else:
-
-            if self.is_saved == None:
-                self.is_saved = True
-                output = self.estimator.forward(x, mask, mu, t, spks, cond)
-                torch.save(x, "x.pt")
-                torch.save(mask, "mask.pt")
-                torch.save(mu, "mu.pt")
-                torch.save(t, "t.pt")
-                torch.save(spks, "spks.pt")
-                torch.save(cond, "cond.pt")
-                torch.save(output, "output.pt")
-                dummy_input = (x, mask, mu, t, spks, cond)
-                torch.onnx.export(
-                    self.estimator,
-                    dummy_input,
-                    "estimator_fp32.onnx",
-                    export_params=True,
-                    opset_version=17,
-                    do_constant_folding=True,
-                    input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
-                    output_names=['output'],
-                    dynamic_axes={
-                        'x': {2: 'seq_len'},
-                        'mask': {2: 'seq_len'},
-                        'mu': {2: 'seq_len'},
-                        'cond': {2: 'seq_len'},
-                        'output': {2: 'seq_len'},
-                    }
-                )
-            # print("x, x.shape", x, x.shape)
-            # print("mask, mask.shape", mask, mask.shape)
-            # print("mu, mu.shape", mu, mu.shape)
-            # print("t, t.shape", t, t.shape)
-            # print("spks, spks.shape", spks, spks.shape)
-            # print("cond, cond.shape", cond, cond.shape)
-            return self.estimator.forward(x, mask, mu, t, spks, cond)

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss