|
|
@@ -127,6 +127,8 @@ class ConditionalCFM(BASECFM):
|
|
|
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
|
|
|
else:
|
|
|
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
|
|
|
+ # NOTE need to synchronize when switching stream
|
|
|
+ torch.cuda.current_stream().synchronize()
|
|
|
with stream:
|
|
|
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
|
|
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|