| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383 |
- # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
- # 2024 Alibaba Inc (Xiang Lyu)
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Modified from ESPnet(https://github.com/espnet/espnet)
- """Subsampling layer definition."""
- from typing import Tuple, Union
- import torch
class BaseSubsampling(torch.nn.Module):
    """Common base for the input front-ends (embedding / subsampling) below.

    This base never assigns ``self.pos_enc`` itself; subclasses must set it
    to a module exposing ``position_encoding(offset, size)`` before
    :meth:`position_encoding` is used. Subclasses that change the frame rate
    override ``subsampling_rate`` and ``right_context``.
    """

    def __init__(self):
        super().__init__()
        # Defaults describe a layer that neither downsamples nor looks ahead.
        self.subsampling_rate = 1
        self.right_context = 0

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """Delegate to ``self.pos_enc.position_encoding``.

        Args:
            offset: Start position, forwarded unchanged.
            size: Number of positions requested, forwarded unchanged.

        Returns:
            torch.Tensor: Whatever the positional encoder returns.
        """
        encoder = self.pos_enc
        return encoder.position_encoding(offset, size)
class EmbedinigNoSubsampling(BaseSubsampling):
    """Embed discrete token ids without any subsampling.

    The misspelled class name ("Embedinig") is kept unchanged on purpose:
    renaming it would break any code or config that refers to it by name.

    Args:
        idim (int): Number of embedding rows (size of the id vocabulary).
        odim (int): Embedding / output dimension.
        dropout_rate (float): Accepted for interface parity with the other
            front-ends; this layer does not use it.
        pos_enc_class (torch.nn.Module): Positional-encoding module instance
            (not a class, despite the name); it is called as
            ``pos_enc(x, offset)`` and must return ``(x, pos_emb)``.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct an EmbedinigNoSubsampling object."""
        super().__init__()
        self.embed = torch.nn.Embedding(idim, odim)
        self.pos_enc = pos_enc_class
        # right_context / subsampling_rate keep the BaseSubsampling defaults
        # (0 and 1): an embedding lookup never changes the frame rate.

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed token ids and apply the positional encoding.

        Args:
            x (torch.Tensor): Integer token ids, presumably (#batch, time)
                (TODO confirm with callers); ``nn.Embedding`` appends an
                ``odim`` axis to whatever shape it receives.
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset (int or torch.Tensor): Position offset forwarded
                unchanged to ``pos_enc``.

        Returns:
            torch.Tensor: Embedded features after ``pos_enc``, presumably
                (#batch, time, odim).
            torch.Tensor: Positional embedding returned by ``pos_enc``.
            torch.Tensor: ``x_mask``, returned unchanged (no subsampling).
        """
        x = self.embed(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask
class LinearNoSubsampling(BaseSubsampling):
    """Project features through Linear -> LayerNorm -> Dropout at full rate.

    Args:
        idim (int): Input feature dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout probability applied after LayerNorm.
        pos_enc_class (torch.nn.Module): Positional-encoding module instance,
            called as ``pos_enc(x, offset)`` and expected to return
            ``(x, pos_emb)``.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Build the projection stack and attach the positional encoder."""
        super().__init__()
        # Order matters: the Sequential indices ("out.0", "out.1", ...) are
        # the parameter names stored in checkpoints.
        projection = [
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        ]
        self.out = torch.nn.Sequential(*projection)
        self.pos_enc = pos_enc_class
        # No look-ahead and no change of frame rate.
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` and apply the positional encoding.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset (int or torch.Tensor): Position offset forwarded
                unchanged to ``pos_enc``.

        Returns:
            torch.Tensor: Projected features after ``pos_enc``; the
                projection itself is (#batch, time, odim).
            torch.Tensor: Positional embedding returned by ``pos_enc``.
            torch.Tensor: ``x_mask``, returned unchanged.
        """
        hidden = self.out(x)
        hidden, pos_emb = self.pos_enc(hidden, offset)
        # Frame rate is untouched, so the incoming mask is already aligned.
        return hidden, pos_emb, x_mask
class Conv1dSubsampling2(BaseSubsampling):
    """Convolutional 1D subsampling (to roughly 1/2 length).

    It is designed for Whisper, ref:
    https://github.com/openai/whisper/blob/main/whisper/model.py

    Two Conv1d + GELU layers: the first (kernel 3, stride 1, padding 1) maps
    idim -> odim at full rate, the second (kernel 3, stride 2, padding 1)
    halves the frame rate.

    Args:
        idim (int): Input feature dimension.
        odim (int): Output dimension (conv channels).
        dropout_rate (float): Accepted for interface parity; not used here.
        pos_enc_class (torch.nn.Module): Positional-encoding module instance,
            called as ``pos_enc(x, offset)`` and expected to return
            ``(x, pos_emb)``.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv1dSubsampling2 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
            torch.nn.GELU(),
            torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
            torch.nn.GELU(),
        )
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 2
        # 4 = (3 - 1) * 1 + (3 - 1) * 1
        self.right_context = 4

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset (int or torch.Tensor): Position offset forwarded
                unchanged to ``pos_enc``.

        Returns:
            torch.Tensor: Subsampled features after ``pos_enc``; the conv
                output is (#batch, time', odim) with
                time' = (time + 1) // 2.
            torch.Tensor: Positional embedding returned by ``pos_enc``.
            torch.Tensor: Subsampled mask (#batch, 1, time').
        """
        x = x.transpose(1, 2)  # (b, f, t): Conv1d expects channels first
        x = self.conv(x)
        x = x.transpose(1, 2)  # (b, t', f)
        x, pos_emb = self.pos_enc(x, offset)
        # The strided conv (k=3, s=2, p=1) centres output frame j on input
        # frame 2j and emits (time + 1) // 2 frames; x_mask[..., ::2] has
        # exactly that length for every `time`. Always starting at 0 (rather
        # than at an offset that depends on the parity of the padded length)
        # gives a sequence of valid length L exactly (L + 1) // 2 valid
        # output frames, independent of how much padding its batch carries.
        return x, pos_emb, x_mask[:, :, ::2]
class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling (to roughly 1/4 length).

    Two Conv2d(kernel=3, stride=2, no padding) + ReLU layers over the
    (time, feature) plane, then a Linear projection of the flattened
    channel x frequency axis back to ``odim``.

    Args:
        idim (int): Input feature dimension.
        odim (int): Output dimension (also the conv channel count).
        dropout_rate (float): Accepted for interface parity; not used here.
        pos_enc_class (torch.nn.Module): Positional-encoding module instance,
            called as ``pos_enc(x, offset)`` and expected to return
            ``(x, pos_emb)``.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling4 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Each unpadded k=3/s=2 conv maps a length n to (n - 1) // 2; the
        # frequency axis shrinks the same way, giving this flattened size.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 4
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset (int or torch.Tensor): Position offset forwarded
                unchanged to ``pos_enc``.

        Returns:
            torch.Tensor: Subsampled features after ``pos_enc``; the
                projection yields (#batch, time', odim) with
                time' = ((time - 1) // 2 - 1) // 2 (about time // 4,
                e.g. 100 -> 24).
            torch.Tensor: Positional embedding returned by ``pos_enc``.
            torch.Tensor: Subsampled mask (#batch, 1, time').
        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        # Move time next to batch and flatten channels x frequency.
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        # Output frame j of each k=3/s=2 conv spans input frames 2j..2j+2;
        # sampling the mask at 2::2 keeps the mask of each window's last
        # frame and matches the conv output length at both stages.
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
class Conv2dSubsampling6(BaseSubsampling):
    """Convolutional 2D subsampling (to roughly 1/6 length).

    Conv2d(kernel=3, stride=2) + ReLU followed by Conv2d(kernel=5, stride=3)
    + ReLU (both unpadded), then a Linear projection of the flattened
    channel x frequency axis back to ``odim``.

    Args:
        idim (int): Input feature dimension.
        odim (int): Output dimension (also the conv channel count).
        dropout_rate (float): Accepted for interface parity; not used here.
        pos_enc_class (torch.nn.Module): Positional-encoding module instance,
            called as ``pos_enc(x, offset)`` and expected to return
            ``(x, pos_emb)``.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling6 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        # k=3/s=2 maps n -> (n - 1) // 2; k=5/s=3 maps n -> (n - 2) // 3.
        self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
                                      odim)
        self.pos_enc = pos_enc_class
        # 10 = (3 - 1) * 1 + (5 - 1) * 2
        self.subsampling_rate = 6
        self.right_context = 10

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset (int or torch.Tensor): Position offset forwarded
                unchanged to ``pos_enc``.

        Returns:
            torch.Tensor: Subsampled features after ``pos_enc``; the
                projection yields (#batch, time', odim) with
                time' = ((time - 1) // 2 - 2) // 3 (about time // 6).
            torch.Tensor: Positional embedding returned by ``pos_enc``.
            torch.Tensor: Subsampled mask (#batch, 1, time').
        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        # Move time next to batch and flatten channels x frequency.
        x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        # Windows span 2j..2j+2 (first conv) and 3j..3j+4 (second conv);
        # sampling at 2::2 then 4::3 keeps the mask of each window's last
        # frame and matches the conv output length at both stages.
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
class Conv2dSubsampling8(BaseSubsampling):
    """Convolutional 2D subsampling (to roughly 1/8 length).

    Three Conv2d(kernel=3, stride=2, no padding) + ReLU layers over the
    (time, feature) plane, then a Linear projection of the flattened
    channel x frequency axis back to ``odim``.

    Args:
        idim (int): Input feature dimension.
        odim (int): Output dimension (also the conv channel count).
        dropout_rate (float): Accepted for interface parity; not used here.
        pos_enc_class (torch.nn.Module): Positional-encoding module instance,
            called as ``pos_enc(x, offset)`` and expected to return
            ``(x, pos_emb)``.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling8 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Each unpadded k=3/s=2 conv maps a length n to (n - 1) // 2; the
        # frequency axis goes through three such reductions.
        self.linear = torch.nn.Linear(
            odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 8
        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
        self.right_context = 14

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset (int or torch.Tensor): Position offset forwarded
                unchanged to ``pos_enc``.

        Returns:
            torch.Tensor: Subsampled features after ``pos_enc``; the
                projection yields (#batch, time', odim) with
                time' = (((time - 1) // 2 - 1) // 2 - 1) // 2
                (about time // 8).
            torch.Tensor: Positional embedding returned by ``pos_enc``.
            torch.Tensor: Subsampled mask (#batch, 1, time').
        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        # Move time next to batch and flatten channels x frequency.
        x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        # Each k=3/s=2 stage's output frame j spans input frames 2j..2j+2;
        # sampling at 2::2 per stage keeps the mask of each window's last
        # frame and matches the conv output length.
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
class LegacyLinearNoSubsampling(BaseSubsampling):
    """Project features through Linear -> LayerNorm -> Dropout -> ReLU.

    Same as the plain linear front-end but with a trailing ReLU; the frame
    rate is left unchanged.

    Args:
        idim (int): Input feature dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout probability applied after LayerNorm.
        pos_enc_class (torch.nn.Module): Positional-encoding module instance,
            called as ``pos_enc(x, offset)`` and expected to return
            ``(x, pos_emb)``.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Build the projection stack and attach the positional encoder."""
        super().__init__()
        # Order matters: the Sequential indices ("out.0", "out.1", ...) are
        # the parameter names stored in checkpoints.
        stages = [
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
        ]
        self.out = torch.nn.Sequential(*stages)
        self.pos_enc = pos_enc_class
        # No look-ahead and no change of frame rate.
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` and apply the positional encoding.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset (int or torch.Tensor): Position offset forwarded
                unchanged to ``pos_enc``.

        Returns:
            torch.Tensor: Projected features after ``pos_enc``; the
                projection itself is (#batch, time, odim).
            torch.Tensor: Positional embedding returned by ``pos_enc``.
            torch.Tensor: ``x_mask``, returned unchanged.
        """
        features = self.out(x)
        features, pos_emb = self.pos_enc(features, offset)
        # Frame rate is untouched, so the incoming mask is already aligned.
        return features, pos_emb, x_mask
|