|
@@ -396,6 +396,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
|
|
encoders_kv_cache_list = []
|
|
encoders_kv_cache_list = []
|
|
|
for index, layer in enumerate(self.encoders):
|
|
for index, layer in enumerate(self.encoders):
|
|
|
xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
|
|
xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
|
|
|
|
|
+ encoders_kv_cache_list.append(encoders_kv_cache_new)
|
|
|
encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
|
|
encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
|
|
|
|
|
|
|
|
# upsample
|
|
# upsample
|
|
@@ -426,4 +427,4 @@ class UpsampleConformerEncoder(torch.nn.Module):
|
|
|
# Here we assume the mask is not changed in encoder layers, so just
|
|
# Here we assume the mask is not changed in encoder layers, so just
|
|
|
# return the masks before encoder layers, and the masks will be used
|
|
# return the masks before encoder layers, and the masks will be used
|
|
|
# for cross attention with decoder later
|
|
# for cross attention with decoder later
|
|
|
- return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache_new, upsample_offset, upsample_conv_cache, upsample_kv_cache_new)
|
|
|
|
|
|
|
+ return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
|