positionwise_feed_forward.py

# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    The feed forward network is applied independently at each position of
    the sequence. The output dim is the same as the input dim.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function.
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
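

# A minimal usage sketch (illustrative, not part of the original file): the
# layer maps (B, L, D) -> (B, L, D), expanding to `hidden_units` internally
# before projecting back down. The dimensions below are arbitrary example
# values.
#
#   ffn = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
#   xs = torch.randn(4, 100, 256)   # (B, L, D)
#   ys = ffn(xs)                    # (4, 100, 256), same shape as the input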


class MoEFFNLayer(torch.nn.Module):
    """Mixture-of-experts with positionwise feed forward layers.

    See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
    The output dim is the same as the input dim.

    Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
    https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219

    Args:
        n_expert (int): The number of experts.
        n_expert_per_token (int): The actual number of experts used for
            each frame.
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function.
    """

    def __init__(
        self,
        n_expert: int,
        n_expert_per_token: int,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        super(MoEFFNLayer, self).__init__()
        # Router: scores each frame against every expert (no bias).
        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
        self.experts = torch.nn.ModuleList(
            PositionwiseFeedForward(idim, hidden_units, dropout_rate,
                                    activation) for _ in range(n_expert))
        self.n_expert_per_token = n_expert_per_token

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        B, L, D = xs.size()  # batch size, sequence length, embedding dim (idim)
        xs = xs.view(-1, D)  # (B*L, D)
        router = self.gate(xs)  # (B*L, n_expert)
        # Keep only the top-k experts per frame.
        logits, indices = torch.topk(
            router, self.n_expert_per_token
        )  # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token)
        # Normalize the selected logits so each frame's expert weights sum to 1.
        weights = torch.nn.functional.softmax(
            logits, dim=1,
            dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_per_token)
        output = torch.zeros_like(xs)  # (B*L, D)
        # Run each expert only on the frames routed to it, then accumulate
        # its weighted contribution.
        for i, expert in enumerate(self.experts):
            mask = indices == i
            batch_idx, ith_expert = torch.where(mask)
            output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
                xs[batch_idx])
        return output.view(B, L, D)
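

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # file): both layers should preserve the (B, L, D) input shape. All
    # dimensions below are arbitrary example values.
    torch.manual_seed(0)
    xs = torch.randn(2, 5, 16)  # (B, L, D) with D == idim
    ffn = PositionwiseFeedForward(idim=16, hidden_units=64, dropout_rate=0.1)
    assert ffn(xs).shape == (2, 5, 16)
    moe = MoEFFNLayer(n_expert=4,
                      n_expert_per_token=2,
                      idim=16,
                      hidden_units=64,
                      dropout_rate=0.1)
    assert moe(xs).shape == (2, 5, 16)
    print("shape checks passed")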