# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
#               2020 Northwestern Polytechnical University (Pengcheng Guo)
#               2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swish() activation function for Conformer."""
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
  21. class Swish(torch.nn.Module):
  22. """Construct an Swish object."""
  23. def forward(self, x: torch.Tensor) -> torch.Tensor:
  24. """Return Swish activation function."""
  25. return x * torch.sigmoid(x)
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
  28. class Snake(nn.Module):
  29. '''
  30. Implementation of a sine-based periodic activation function
  31. Shape:
  32. - Input: (B, C, T)
  33. - Output: (B, C, T), same shape as the input
  34. Parameters:
  35. - alpha - trainable parameter
  36. References:
  37. - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
  38. https://arxiv.org/abs/2006.08195
  39. Examples:
  40. >>> a1 = snake(256)
  41. >>> x = torch.randn(256)
  42. >>> x = a1(x)
  43. '''
  44. def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
  45. '''
  46. Initialization.
  47. INPUT:
  48. - in_features: shape of the input
  49. - alpha: trainable parameter
  50. alpha is initialized to 1 by default, higher values = higher-frequency.
  51. alpha will be trained along with the rest of your model.
  52. '''
  53. super(Snake, self).__init__()
  54. self.in_features = in_features
  55. # initialize alpha
  56. self.alpha_logscale = alpha_logscale
  57. if self.alpha_logscale: # log scale alphas initialized to zeros
  58. self.alpha = Parameter(torch.zeros(in_features) * alpha)
  59. else: # linear scale alphas initialized to ones
  60. self.alpha = Parameter(torch.ones(in_features) * alpha)
  61. self.alpha.requires_grad = alpha_trainable
  62. self.no_div_by_zero = 0.000000001
  63. def forward(self, x):
  64. '''
  65. Forward pass of the function.
  66. Applies the function to the input elementwise.
  67. Snake ∶= x + 1/a * sin^2 (xa)
  68. '''
  69. alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
  70. if self.alpha_logscale:
  71. alpha = torch.exp(alpha)
  72. x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
  73. return x