# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
#               2020 Northwestern Polytechnical University (Pengcheng Guo)
#               2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """Swish() activation function for Conformer."""
- import torch
- from torch import nn, sin, pow
- from torch.nn import Parameter


class Swish(torch.nn.Module):
    """Construct a Swish object."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Swish activation: x * sigmoid(x)."""
        return x * torch.sigmoid(x)
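

# A minimal usage sketch (an assumption, not part of the original file): Swish
# as defined above is numerically identical to torch.nn.functional.silu, so the
# two can be checked against each other:
#
#     x = torch.randn(4, 80)  # hypothetical (batch, dim) input
#     assert torch.allclose(Swish()(x), torch.nn.functional.silu(x))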


# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
class Snake(nn.Module):
    """
    Implementation of a sine-based periodic activation function.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    Parameters:
        - alpha: trainable parameter

    References:
        - This activation function is from the paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(2, 256, 100)
        >>> x = a1(x)
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """
        Initialization.

        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
              alpha is initialized to 1 by default; higher values give higher-frequency oscillations.
              alpha is trained along with the rest of the model.
        """
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log-scale alphas initialized to zeros (exp(0) = 1 in forward)
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear-scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.

        Applies the function to the input elementwise:
            Snake(x) := x + (1/alpha) * sin^2(alpha * x)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
        return x
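

# A minimal sanity-check sketch, an assumption rather than part of the original
# module: it feeds a random (B, C, T) tensor through Snake and checks that with
# alpha fixed at 0 the layer reduces to the identity, since sin(0) = 0.
if __name__ == "__main__":
    snake = Snake(in_features=80)
    x = torch.randn(2, 80, 50)  # hypothetical (batch, channel, time) input
    y = snake(x)
    assert y.shape == x.shape  # Snake is elementwise and shape-preserving

    frozen = Snake(in_features=80, alpha=0.0, alpha_trainable=False)
    assert torch.allclose(frozen(x), x)  # x + (1/eps) * sin^2(0 * x) == x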