|
@@ -92,7 +92,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|
|
|
|
|
|
|
mask = (~make_pad_mask(feat_len)).to(h)
|
|
mask = (~make_pad_mask(feat_len)).to(h)
|
|
|
# NOTE this is unnecessary, feat/h already same shape
|
|
# NOTE this is unnecessary, feat/h already same shape
|
|
|
- feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
|
|
|
|
loss, _ = self.decoder.compute_loss(
|
|
loss, _ = self.decoder.compute_loss(
|
|
|
feat.transpose(1, 2).contiguous(),
|
|
feat.transpose(1, 2).contiguous(),
|
|
|
mask.unsqueeze(1),
|
|
mask.unsqueeze(1),
|
|
@@ -214,7 +213,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|
|
h = self.encoder_proj(h)
|
|
h = self.encoder_proj(h)
|
|
|
|
|
|
|
|
# get conditions
|
|
# get conditions
|
|
|
- feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
|
|
|
|
conds = torch.zeros(feat.shape, device=token.device)
|
|
conds = torch.zeros(feat.shape, device=token.device)
|
|
|
for i, j in enumerate(feat_len):
|
|
for i, j in enumerate(feat_len):
|
|
|
if random.random() < 0.5:
|
|
if random.random() < 0.5:
|