| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
- # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Modified from ESPnet(https://github.com/espnet/espnet)
- """Encoder self-attention layer definition."""
- from typing import Optional, Tuple
- import torch
- from torch import nn
class TransformerEncoderLayer(nn.Module):
    """Standard Transformer encoder block: self-attention followed by a FFN.

    Each of the two sub-blocks is wrapped in a residual connection with
    dropout, and a LayerNorm is applied either before the sub-block
    (pre-norm) or after the residual sum (post-norm).

    Args:
        size (int): Model (feature) dimension.
        self_attn (torch.nn.Module): Self-attention module, e.g.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`.
        feed_forward (torch.nn.Module): Position-wise feed-forward module,
            e.g. `PositionwiseFeedForward`.
        dropout_rate (float): Dropout probability applied to each sub-block
            output before the residual addition.
        normalize_before (bool):
            True: apply layer_norm before each sub-block (pre-norm).
            False: apply layer_norm after each residual sum (post-norm).
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: torch.nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Build the layer from its attention and feed-forward submodules."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        layer_norm_eps = 1e-5
        self.norm1 = nn.LayerNorm(size, eps=layer_norm_eps)  # attention
        self.norm2 = nn.LayerNorm(size, eps=layer_norm_eps)  # feed-forward
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one encoder block over the input features.

        Args:
            x (torch.Tensor): Input features (#batch, time, size).
            mask (torch.Tensor): Attention mask (#batch, time, time);
                a (0, 0, 0) tensor denotes a fake mask.
            pos_emb (torch.Tensor): Forwarded to the attention module; kept
                so the signature matches ConformerEncoderLayer.
            mask_pad (torch.Tensor): Unused here; present only so the API
                matches ConformerEncoderLayer.
            att_cache (torch.Tensor): Cached attention KEY & VALUE,
                (#batch=1, head, cache_t1, d_k * 2) with head * d_k == size.
            cnn_cache (torch.Tensor): Unused here; the convolution cache
                only matters for ConformerEncoderLayer.

        Returns:
            torch.Tensor: Output features (#batch, time, size).
            torch.Tensor: The input `mask`, returned unchanged.
            torch.Tensor: Updated attention cache as returned by
                `self_attn`, (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: An empty (0, 0, 0) placeholder for the cnn cache.
        """
        pre_norm = self.normalize_before

        # Self-attention sub-block.
        skip = x
        attn_in = self.norm1(x) if pre_norm else x
        attn_out, new_att_cache = self.self_attn(
            attn_in, attn_in, attn_in, mask, pos_emb=pos_emb, cache=att_cache)
        x = skip + self.dropout(attn_out)
        if not pre_norm:
            x = self.norm1(x)

        # Position-wise feed-forward sub-block.
        skip = x
        ff_in = self.norm2(x) if pre_norm else x
        x = skip + self.dropout(self.feed_forward(ff_in))
        if not pre_norm:
            x = self.norm2(x)

        # There is no convolution here; hand back an empty tensor so the
        # return signature lines up with ConformerEncoderLayer.
        placeholder_cnn_cache = torch.zeros(
            (0, 0, 0), dtype=x.dtype, device=x.device)
        return x, mask, new_att_cache, placeholder_cnn_cache
class ConformerEncoderLayer(nn.Module):
    """Conformer encoder block.

    The sub-blocks run in this order:
    optional macaron feed-forward -> self-attention -> optional convolution
    -> optional feed-forward -> final LayerNorm (only when a convolution
    module is present). Every sub-block is residual with dropout.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (Optional[torch.nn.Module]): Feed-forward module
            instance, e.g. `PositionwiseFeedForward`. If None, the trailing
            feed-forward sub-block is skipped.
        feed_forward_macaron (Optional[torch.nn.Module]): Additional
            feed-forward module applied before self-attention. When given,
            both feed-forward outputs are scaled by 0.5.
        conv_module (Optional[torch.nn.Module]): Convolution module
            instance, e.g. `ConvolutionModule`.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        # norm_ff is created even when feed_forward is None so the parameter
        # layout (and therefore saved checkpoints) does not depend on it.
        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FFN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
            # Two feed-forward sub-blocks share the work: half-step each.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-5)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module,
                (#batch, 1, time); (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: The input `mask`, returned unchanged.
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2), or an
                empty (0, 0, 0) tensor when there is no conv module.
        """
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module; the constructor allows feed_forward=None, so
        # skip the sub-block instead of calling None.
        if self.feed_forward is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
            if not self.normalize_before:
                x = self.norm_ff(x)

        # norm_final only exists when a conv module was supplied.
        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
|