|
|
@@ -11,6 +11,7 @@
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
+import threading
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from matcha.models.components.flow_matching import BASECFM
|
|
|
@@ -30,6 +31,7 @@ class ConditionalCFM(BASECFM):
|
|
|
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
|
|
# Just change the architecture of the estimator here
|
|
|
self.estimator = estimator
|
|
|
+ self.lock = threading.Lock()
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
|
|
@@ -120,7 +122,25 @@ class ConditionalCFM(BASECFM):
|
|
|
return sol[-1].float()
|
|
|
|
|
|
def forward_estimator(self, x, mask, mu, t, spks, cond):
    """Run the estimator on one batch and return its output.

    When ``self.estimator`` is a ``torch.nn.Module`` its ``forward`` is
    called directly and the result returned.  Otherwise the estimator is
    driven as a TensorRT-style execution context: input shapes are
    declared with ``set_input_shape`` and ``execute_v2`` is invoked with
    raw device pointers, all while holding ``self.lock`` so that only one
    caller at a time uses the context.

    Args:
        x: Input tensor; in the engine path it is also the output buffer
            and is overwritten in place.
        mask, mu, t, spks, cond: Remaining estimator inputs, passed
            through unchanged.

    Returns:
        The module's output in the ``torch.nn.Module`` path; ``x`` itself
        (holding the engine output) in the engine path.
    """
    if isinstance(self.estimator, torch.nn.Module):
        return self.estimator.forward(x, mask, mu, t, spks, cond)

    # Keep every contiguous tensor referenced until execute_v2 returns.
    # Taking data_ptr() of a temporary `.contiguous()` copy would leave the
    # engine with an address whose storage has already been released (and
    # possibly reused by the next copy) whenever an input is non-contiguous.
    x_buf = x if x.is_contiguous() else x.contiguous()
    inputs = [
        x_buf,
        mask.contiguous(),
        mu.contiguous(),
        t.contiguous(),
        spks.contiguous(),
        cond.contiguous(),
    ]
    frames = x.size(2)

    with self.lock:
        # NOTE(review): the leading 2 and the 80 are fixed here; presumably
        # they match the batch size and channel count the engine was built
        # for -- confirm against the engine export.
        self.estimator.set_input_shape('x', (2, 80, frames))
        self.estimator.set_input_shape('mask', (2, 1, frames))
        self.estimator.set_input_shape('mu', (2, 80, frames))
        self.estimator.set_input_shape('t', (2,))
        self.estimator.set_input_shape('spks', (2, 80))
        self.estimator.set_input_shape('cond', (2, 80, frames))
        # run trt engine; the last binding is the output, which aliases the
        # (contiguous) x buffer so the result overwrites the input.
        bindings = [tensor.data_ptr() for tensor in inputs]
        bindings.append(x_buf.data_ptr())
        self.estimator.execute_v2(bindings)

    if x_buf is not x:
        # x was non-contiguous: the engine wrote into the dense copy, so
        # move the result back to honour the "x is updated in place" contract.
        x.copy_(x_buf)
    return x
|
|
|
|
|
|
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
|
|
"""Computes diffusion loss
|