| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
- # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
- # 2024 Alibaba Inc (Xiang Lyu)
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Modified from ESPnet(https://github.com/espnet/espnet)
- """Positional Encoding Module."""
- import math
- from typing import Tuple, Union
- import torch
- import torch.nn.functional as F
- import numpy as np
class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length
    :param bool reverse: accepted for interface compatibility; this class
        does not use it.

    PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))

    The table is kept as a plain tensor attribute ``self.pe`` of shape
    ``(1, max_len, d_model)`` (not a registered buffer), so it is moved to
    the input's device lazily in :meth:`forward`.
    """

    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Construct an PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len
        self.pe = torch.zeros(self.max_len, self.d_model)
        position = torch.arange(0, self.max_len,
                                dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.d_model))
        self.pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # column; trim div_term so the shapes match (no-op for even d_model).
        self.pe[:, 1::2] = torch.cos(position *
                                     div_term[:self.d_model // 2])
        self.pe = self.pe.unsqueeze(0)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int, torch.tensor): position offset

        Returns:
            torch.Tensor: Encoded tensor ``dropout(x * xscale + pe)``.
                Its shape is (batch, time, ...)
            torch.Tensor: The positional encoding that was added (with
                dropout applied), for compatibility with
                RelPositionalEncoding.
        """
        # self.pe is a plain attribute, so it does not follow .to()/.cuda()
        # calls on the module; move it to the input's device here.
        self.pe = self.pe.to(x.device)
        pos_emb = self.position_encoding(offset, x.size(1), False)
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """Return the encoding for ``size`` positions starting at ``offset``.

        Intended for streaming use. Attention: in non-streaming mode dropout
        is applied once over the whole utterance, but in a streaming
        scenario this function is called several times with increasing
        input size, so dropout (when ``apply_dropout`` is True) ends up
        being applied several times.

        Args:
            offset (int or torch.Tensor): start offset. A python int or a
                0-dim tensor selects a single window. Any other tensor is
                treated as per-utterance offsets for batched streaming
                decoding (presumably shape (B,) — it is unsqueezed to
                (B, 1)); positions that come out negative are clamped to 0.
            size (int): required size of position encoding
            apply_dropout (bool): whether to pass the result through
                ``self.dropout``.

        Returns:
            torch.Tensor: (1, size, d_model) for a scalar offset, or
            (B, size, d_model) for batched offsets.
        """
        # How to subscript a Union type:
        # https://github.com/pytorch/pytorch/issues/69434
        if isinstance(offset, int):
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:  # for batched streaming decoding on GPU
            assert torch.max(offset) + size <= self.max_len
            index = offset.unsqueeze(1) + \
                torch.arange(0, size).to(offset.device)  # B X T
            flag = index > 0
            # remove negative offset
            index = index * flag
            pos_emb = F.embedding(index, self.pe[0])  # B X T X d_model
        if apply_dropout:
            pos_emb = self.dropout(pos_emb)
        return pos_emb
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    In contrast to the absolute variant, the sinusoid window is not added to
    the input; the scaled input and the window are returned side by side so
    a relative-position attention layer can combine them.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Initialize class."""
        # ``reverse=True`` is forwarded for interface compatibility; the
        # parent constructor does not currently act on it.
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int or torch.Tensor): start position of the window.

        Returns:
            torch.Tensor: ``dropout(x * xscale)``, shape (batch, time, `*`).
            torch.Tensor: Positional embedding window with dropout applied,
                (1, time, d_model) for a scalar offset.
        """
        self.pe = self.pe.to(x.device)
        seq_len = x.size(1)
        scaled = x * self.xscale
        window = self.position_encoding(offset, seq_len, False)
        return self.dropout(scaled), self.dropout(window)
class WhisperPositionalEncoding(PositionalEncoding):
    """Sinusoid position encoding used in openai-whisper.encoder.

    The input is not scaled (``xscale = 1.0``). The table is laid out with
    all sine channels first and all cosine channels second (concatenated,
    not interleaved), and it is registered as a buffer named ``pe`` of shape
    ``(1, max_len, 2 * (d_model // 2))``.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
        super().__init__(d_model, dropout_rate, max_len)
        self.xscale = 1.0
        half_dim = d_model // 2
        log_step = np.log(10000) / (half_dim - 1)
        inv_freq = torch.exp(-log_step * torch.arange(half_dim))
        steps = torch.arange(max_len).unsqueeze(1)
        angles = steps * inv_freq.unsqueeze(0)
        table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        # The parent stored ``pe`` as a plain attribute; drop it so the name
        # can be registered as a buffer.
        delattr(self, "pe")
        self.register_buffer("pe", table.unsqueeze(0))
class LearnablePositionalEncoding(PositionalEncoding):
    """Learnable position encoding used in openai-whisper.decoder.

    The sinusoid table built by the parent constructor is replaced by a
    trainable parameter ``pe`` of shape ``(1, max_len, d_model)``, and the
    input is not scaled (``xscale = 1.0``).
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
        super().__init__(d_model, dropout_rate, max_len)
        # NOTE(xcsong): overwrite self.pe & self.xscale
        # Seed the parameter with the sinusoid table the parent just built
        # rather than ``torch.empty``: uninitialized memory may hold NaN/Inf
        # and would poison training from scratch. Loading a checkpoint still
        # overwrites these values, so pretrained behavior is unchanged.
        self.pe = torch.nn.Parameter(self.pe.detach().clone())
        self.xscale = 1.0
class NoPositionalEncoding(torch.nn.Module):
    """No position encoding.

    Drop-in stand-in for a positional-encoding module: the input is only
    passed through dropout (no scaling), and the reported position embedding
    is all zeros.
    """

    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Just return zero vector for interface compatibility.

        Args:
            x (torch.Tensor): Input, (batch, time, ...).
            offset (int or torch.Tensor): ignored.

        Returns:
            torch.Tensor: ``dropout(x)``.
            torch.Tensor: zeros of shape (1, time, d_model) on ``x.device``.
        """
        # Allocate directly on the target device instead of CPU + copy.
        pos_emb = torch.zeros(1, x.size(1), self.d_model, device=x.device)
        return self.dropout(x), pos_emb

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """Return a zero encoding of shape (1, size, d_model).

        ``offset`` and ``apply_dropout`` are accepted (and ignored) so this
        method can be called with the same arguments as the other encoders'
        ``position_encoding``.
        """
        return torch.zeros(1, size, self.d_model)
class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    ``self.pe`` holds encodings for relative positions from ``+(L-1)`` down
    to ``-(L-1)`` (length ``2L - 1``), with relative position 0 at index
    ``self.pe.size(1) // 2``. It is regrown on demand when longer inputs
    arrive.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct an PositionalEncoding object."""
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Make sure ``self.pe`` covers inputs of length ``x.size(1)``.

        If the existing table is long enough, it is only cast to ``x``'s
        dtype/device when those differ. Otherwise the table is rebuilt for
        length ``x.size(1)`` and placed on ``x``'s dtype/device.
        """
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` means to the position of query vector and `j` means the
        # position of key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        # With an odd d_model there is one fewer cosine column than sine
        # column; trim div_term for the cosine half (no-op for even d_model).
        cos_div_term = div_term[:self.d_model // 2]
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * cos_div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * cos_div_term)
        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int or torch.Tensor): accepted for interface
                compatibility; not used (see ``position_encoding``).

        Returns:
            torch.Tensor: ``dropout(x * xscale)``, shape (batch, time, `*`).
            torch.Tensor: Relative position embedding with dropout applied,
                shape (1, 2 * time - 1, d_model).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int,
                          apply_dropout: bool = False) -> torch.Tensor:
        """Return the relative encodings for a window of ``size`` positions.

        The window spans relative positions ``size - 1`` down to
        ``-(size - 1)``, so it depends only on ``size``; ``offset`` is
        accepted for interface compatibility and ignored. In a streaming
        scenario this may be called several times with increasing size.

        Args:
            offset (int or torch.tensor): ignored.
            size (int): required size of position encoding
            apply_dropout (bool): whether to pass the result through
                ``self.dropout`` (off by default, matching the historical
                behavior of this method).

        Returns:
            torch.Tensor: Corresponding encoding, (1, 2 * size - 1, d_model).
        """
        # Grow the table when asked for more positions than it holds.
        # Otherwise the slice start below goes negative, wraps around, and
        # silently returns a wrong, truncated window.
        if self.pe.size(1) < 2 * size - 1:
            self.extend_pe(
                torch.tensor(0.0, dtype=self.pe.dtype,
                             device=self.pe.device).expand(1, size))
        center = self.pe.size(1) // 2
        pos_emb = self.pe[:, center - size + 1:center + size]
        if apply_dropout:
            pos_emb = self.dropout(pos_emb)
        return pos_emb
|