Bläddra i källkod

use static_chunk_size in flow training

lyuxiang.lx 8 månader sedan
förälder
incheckning
d9ffd592f6
1 ändrad fil med 2 tillägg och 12 borttagningar
  1. +2 -12
      cosyvoice/transformer/upsample_encoder.py

+ 2 - 12
cosyvoice/transformer/upsample_encoder.py

@@ -286,12 +286,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
             xs = self.global_cmvn(xs)
         xs, pos_emb, masks = self.embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
-        chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk if streaming is True else False,
-                                              self.use_dynamic_left_chunk if streaming is True else False,
-                                              decoding_chunk_size if streaming is True else 0,
-                                              self.static_chunk_size if streaming is True else 0,
-                                              num_decoding_left_chunks if streaming is True else -1)
+        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
         # lookahead + conformer encoder
         xs, _ = self.pre_lookahead_layer(xs)
         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
@@ -304,12 +299,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
         xs, pos_emb, masks = self.up_embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
-        chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk if streaming is True else False,
-                                              self.use_dynamic_left_chunk if streaming is True else False,
-                                              decoding_chunk_size if streaming is True else 0,
-                                              self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
-                                              num_decoding_left_chunks if streaming is True else -1)
+        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
 
         if self.normalize_before: