|
|
@@ -286,12 +286,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
|
|
xs = self.global_cmvn(xs)
|
|
|
xs, pos_emb, masks = self.embed(xs, masks)
|
|
|
mask_pad = masks # (B, 1, T/subsample_rate)
|
|
|
- chunk_masks = add_optional_chunk_mask(xs, masks,
|
|
|
- self.use_dynamic_chunk if streaming is True else False,
|
|
|
- self.use_dynamic_left_chunk if streaming is True else False,
|
|
|
- decoding_chunk_size if streaming is True else 0,
|
|
|
- self.static_chunk_size if streaming is True else 0,
|
|
|
- num_decoding_left_chunks if streaming is True else -1)
|
|
|
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
|
|
|
# lookahead + conformer encoder
|
|
|
xs, _ = self.pre_lookahead_layer(xs)
|
|
|
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
|
|
@@ -304,12 +299,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
|
|
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
|
|
xs, pos_emb, masks = self.up_embed(xs, masks)
|
|
|
mask_pad = masks # (B, 1, T/subsample_rate)
|
|
|
- chunk_masks = add_optional_chunk_mask(xs, masks,
|
|
|
- self.use_dynamic_chunk if streaming is True else False,
|
|
|
- self.use_dynamic_left_chunk if streaming is True else False,
|
|
|
- decoding_chunk_size if streaming is True else 0,
|
|
|
- self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
|
|
|
- num_decoding_left_chunks if streaming is True else -1)
|
|
|
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
|
|
|
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
|
|
|
|
|
|
if self.normalize_before:
|