Browse Source

fix flow matching training for zero shot inference

lyuxiang.lx 1 year ago
parent
commit
9504c3f88b
1 changed files with 6 additions and 0 deletions
  1. 6 0
      cosyvoice/flow/flow.py

+ 6 - 0
cosyvoice/flow/flow.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import Dict, Optional
 import torch
 import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         # get conditions
         conds = torch.zeros(feat.shape, device=token.device)
+        for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(feat_len)).to(h)