decoder_layer.py

# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. """Decoder self-attention layer definition."""
  16. from typing import Optional, Tuple
  17. import torch
  18. from torch import nn
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Inter-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
            If `None` is passed, inter-attention is not used, e.g. in
            CIF, GPT, and other decoder-only models.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """
    def __init__(
        self,
        size: int,
        self_attn: nn.Module,
        src_attn: Optional[nn.Module],
        feed_forward: nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct a DecoderLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-5)
        self.norm2 = nn.LayerNorm(size, eps=1e-5)
        self.norm3 = nn.LayerNorm(size, eps=1e-5)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor
                (#batch, maxlen_out, maxlen_out).
            memory (torch.Tensor): Encoded memory
                (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask
                (#batch, maxlen_in).
            cache (torch.Tensor): Cached output from the preceding decoding
                step (#batch, maxlen_out - 1, size), or None.

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor
                (#batch, maxlen_out, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = tgt_mask[:, -1:, :]

        x = residual + self.dropout(
            self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
        if not self.normalize_before:
            x = self.norm1(x)
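
        # Cross-attention over the encoder memory; skipped when src_attn is
        # None (decoder-only models such as CIF or GPT, see class docstring).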
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm2(x)
            x = residual + self.dropout(
                self.src_attn(x, memory, memory, memory_mask)[0])
            if not self.normalize_before:
                x = self.norm2(x)

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
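

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the layer above).
# `_ToyAttention` and `_ToyFeedForward` are hypothetical stand-ins that only
# mimic the call signatures DecoderLayer expects: attention modules are
# called as attn(query, key, value, mask) and their first return value is
# used; the feed-forward maps (#batch, time, size) -> (#batch, time, size).
# In practice, MultiHeadedAttention and PositionwiseFeedForward instances
# would be passed instead.
if __name__ == "__main__":

    class _ToyAttention(nn.Module):
        """Hypothetical attention stand-in: projects the query, ignores mask."""

        def __init__(self, size: int):
            super().__init__()
            self.proj = nn.Linear(size, size)

        def forward(self, query, key, value, mask):
            # Real attention would mix `value` by query/key scores; here we
            # only preserve the expected output shape (#batch, time_q, size).
            return self.proj(query), None

    class _ToyFeedForward(nn.Module):
        """Hypothetical position-wise feed-forward stand-in."""

        def __init__(self, size: int, hidden: int = 64):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(size, hidden), nn.ReLU(),
                                     nn.Linear(hidden, size))

        def forward(self, x):
            return self.net(x)

    batch, t_out, t_in, size = 2, 5, 7, 16
    layer = DecoderLayer(
        size=size,
        self_attn=_ToyAttention(size),
        src_attn=_ToyAttention(size),
        feed_forward=_ToyFeedForward(size),
        dropout_rate=0.1,
    )
    layer.eval()  # disable dropout for a deterministic shape check

    tgt = torch.randn(batch, t_out, size)
    tgt_mask = torch.ones(batch, t_out, t_out, dtype=torch.bool)
    memory = torch.randn(batch, t_in, size)
    memory_mask = torch.ones(batch, 1, t_in, dtype=torch.bool)

    # Full forward pass over all output frames.
    x, _, _, _ = layer(tgt, tgt_mask, memory, memory_mask)
    print(x.shape)  # torch.Size([2, 5, 16])

    # Incremental decoding: cache the first maxlen_out - 1 frames, compute
    # only the last frame, and let the layer re-attach the cache on return.
    cache = x[:, :-1, :]
    x_step, _, _, _ = layer(tgt, tgt_mask, memory, memory_mask, cache=cache)
    print(x_step.shape)  # torch.Size([2, 5, 16])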