|
|
@@ -136,41 +136,26 @@ class ConditionalCFM(BASECFM):
|
|
|
'mask': mask.cpu().numpy(),
|
|
|
'mu': mu.cpu().numpy(),
|
|
|
't': t.cpu().numpy(),
|
|
|
- 'spk': spks.cpu().numpy(),
|
|
|
- 'cond': cond.cpu().numpy(),
|
|
|
- 'mask_rand': torch.randn(1, 1, 1).numpy()
|
|
|
+ 'spks': spks.cpu().numpy(),
|
|
|
+ 'cond': cond.cpu().numpy()
|
|
|
}
|
|
|
output = self.estimator.run(None, ort_inputs)[0]
|
|
|
return torch.tensor(output, dtype=x.dtype, device=x.device)
|
|
|
else:
|
|
|
- if not x.is_contiguous():
|
|
|
- x = x.contiguous()
|
|
|
- if not mask.is_contiguous():
|
|
|
- mask = mask.contiguous()
|
|
|
- if not mu.is_contiguous():
|
|
|
- mu = mu.contiguous()
|
|
|
- if not t.is_contiguous():
|
|
|
- t = t.contiguous()
|
|
|
- if not spks.is_contiguous():
|
|
|
- spks = spks.contiguous()
|
|
|
- if not cond.is_contiguous():
|
|
|
- cond = cond.contiguous()
|
|
|
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
|
|
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
|
|
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
|
|
self.estimator.set_input_shape('t', (2,))
|
|
|
- self.estimator.set_input_shape('spk', (2, 80))
|
|
|
+ self.estimator.set_input_shape('spks', (2, 80))
|
|
|
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
|
|
- self.estimator.set_input_shape('mask_rand', (1, 1, 1))
|
|
|
# run trt engine
|
|
|
- self.estimator.execute_v2([x.data_ptr(),
|
|
|
- mask.data_ptr(),
|
|
|
- mu.data_ptr(),
|
|
|
- t.data_ptr(),
|
|
|
- spks.data_ptr(),
|
|
|
- cond.data_ptr(),
|
|
|
- torch.randn(1, 1, 1).to(x.device).data_ptr(),
|
|
|
- x.data_ptr()])
|
|
|
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
|
|
|
+ mask.contiguous().data_ptr(),
|
|
|
+ mu.contiguous().data_ptr(),
|
|
|
+ t.contiguous().data_ptr(),
|
|
|
+ spks.contiguous().data_ptr(),
|
|
|
+ cond.contiguous().data_ptr(),
|
|
|
+ x.data_ptr()])
|
|
|
return x
|
|
|
|
|
|
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
|
|
@@ -241,7 +226,7 @@ class CausalConditionalCFM(ConditionalCFM):
|
|
|
"""
|
|
|
|
|
|
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
|
|
|
- if self.sp16 is True:
|
|
|
+ if self.fp16 is True:
|
|
|
z = z.half()
|
|
|
# fix prompt and overlap part mu and z
|
|
|
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|