# encoder_layer.py
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""

from typing import Optional, Tuple

import torch
from torch import nn


class TransformerEncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: torch.nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-12)
        self.norm2 = nn.LayerNorm(size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): just for interface compatibility
                with ConformerEncoderLayer.
            mask_pad (torch.Tensor): not used in the transformer layer,
                just for a unified api with the conformer layer.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in the conformer layer
                (#batch=1, size, cache_t2), not used here, just for
                interface compatibility with ConformerEncoderLayer.
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask,
                                              pos_emb=pos_emb,
                                              cache=att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        return x, mask, new_att_cache, fake_cnn_cache
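

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustration only, not part of the upstream module).
# It exercises TransformerEncoderLayer.forward() with two toy stand-ins whose
# names (_ToySelfAttention, _ToyFeedForward) are invented here; in WeNet the
# layer is normally built with MultiHeadedAttention and PositionwiseFeedForward.
def _demo_transformer_encoder_layer() -> None:
    class _ToySelfAttention(nn.Module):
        """Minimal self-attention stand-in returning (output, new_att_cache)."""

        def __init__(self, size: int):
            super().__init__()
            self.proj = nn.Linear(size, size)

        def forward(self, q, k, v, mask, pos_emb=None, cache=None):
            # Ignore mask/pos_emb/cache; only mimic the expected interface.
            return self.proj(q), torch.zeros((0, 0, 0, 0))

    class _ToyFeedForward(nn.Module):
        """Minimal position-wise feed-forward stand-in."""

        def __init__(self, size: int):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(size, 4 * size), nn.ReLU(),
                                     nn.Linear(4 * size, size))

        def forward(self, x):
            return self.net(x)

    size, batch, time = 8, 2, 5
    layer = TransformerEncoderLayer(size,
                                    _ToySelfAttention(size),
                                    _ToyFeedForward(size),
                                    dropout_rate=0.1)
    x = torch.randn(batch, time, size)
    mask = torch.ones(batch, time, time, dtype=torch.bool)
    out, out_mask, att_cache, cnn_cache = layer(x, mask, pos_emb=torch.empty(0))
    assert out.shape == (batch, time, size)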


class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-12)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for the conv module
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in the conformer layer
                (#batch=1, size, cache_t2).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
        """
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
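

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustration only, not part of the upstream module).
# It exercises ConformerEncoderLayer.forward() with toy stand-ins for the
# attention, feed-forward and convolution sub-modules; all names below are
# invented for this demo, whereas WeNet normally wires in
# RelPositionMultiHeadedAttention, PositionwiseFeedForward and ConvolutionModule.
def _demo_conformer_encoder_layer() -> None:
    class _ToySelfAttention(nn.Module):
        """Self-attention stand-in returning (output, new_att_cache)."""

        def __init__(self, size: int):
            super().__init__()
            self.proj = nn.Linear(size, size)

        def forward(self, q, k, v, mask, pos_emb, cache):
            # Ignore mask/pos_emb/cache; only mimic the expected interface.
            return self.proj(q), torch.zeros((0, 0, 0, 0))

    class _ToyFeedForward(nn.Module):
        """Position-wise feed-forward stand-in."""

        def __init__(self, size: int):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(size, 4 * size), nn.ReLU(),
                                     nn.Linear(4 * size, size))

        def forward(self, x):
            return self.net(x)

    class _ToyConvModule(nn.Module):
        """Convolution stand-in returning (output, new_cnn_cache)."""

        def __init__(self, size: int):
            super().__init__()
            self.conv = nn.Conv1d(size, size, kernel_size=3, padding=1)

        def forward(self, x, mask_pad, cache):
            # (B, T, C) -> (B, C, T) for Conv1d, then back; ignore the caches.
            y = self.conv(x.transpose(1, 2)).transpose(1, 2)
            return y, torch.zeros((0, 0, 0))

    size, batch, time = 8, 2, 5
    layer = ConformerEncoderLayer(size,
                                  _ToySelfAttention(size),
                                  feed_forward=_ToyFeedForward(size),
                                  feed_forward_macaron=_ToyFeedForward(size),
                                  conv_module=_ToyConvModule(size),
                                  dropout_rate=0.1)
    x = torch.randn(batch, time, size)
    mask = torch.ones(batch, time, time, dtype=torch.bool)
    pos_emb = torch.zeros(1, time, size)
    out, out_mask, att_cache, cnn_cache = layer(x, mask, pos_emb)
    assert out.shape == (batch, time, size)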