# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from cosyvoice.transformer.convolution import ConvolutionModule
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice.utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.mask import add_optional_chunk_mask


class Upsample1D(nn.Module):
    """A 1D nearest-neighbour upsampling layer followed by a convolution.

    A usage sketch follows this class.

    Parameters:
        channels (`int`):
            number of input channels.
        out_channels (`int`):
            number of output channels.
        stride (`int`, default `2`):
            upsampling factor; the time axis is enlarged by this factor
            before the convolution.
    """
    def __init__(self, channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # In this mode, first repeat (nearest) interpolate, then conv with stride=1
        self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor,
                conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
        if conv_cache.size(2) == 0:
            # first chunk / non-streaming: left-pad so the conv stays causal
            outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
        else:
            # streaming: prepend the cached left context from the previous chunk
            assert conv_cache.size(2) == self.stride * 2
            outputs = torch.concat([conv_cache, outputs], dim=2)
        conv_cache_new = outputs[:, :, -self.stride * 2:]
        outputs = self.conv(outputs)
        return outputs, input_lengths * self.stride, conv_cache_new
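
# Usage sketch (illustrative only, not part of the module): with stride=2 the
# layer doubles the time axis, and in streaming mode the last `stride * 2`
# frames are carried over as `conv_cache` for the next chunk. Tensor shapes
# below are example assumptions.
#
#   up = Upsample1D(channels=512, out_channels=512, stride=2)
#   x = torch.randn(1, 512, 40)                      # (B, C, T)
#   y, y_lens, cache = up(x, torch.tensor([40]))     # y: (1, 512, 80)
#   y2, _, cache = up(x, torch.tensor([40]), cache)  # next chunk reuses cache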


class PreLookaheadLayer(nn.Module):
    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )
    def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0),
                conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        inputs: (batch_size, seq_len, channels)
        context: (batch_size, pre_lookahead_len, channels), extra future frames
            supplied in streaming mode; empty otherwise.
        conv2_cache: cached left context of conv2 from the previous chunk.
        Returns the lookahead-enhanced features (same shape as inputs) and the
        updated conv2 cache. A usage sketch follows this class.
        """
        outputs = inputs.transpose(1, 2).contiguous()
        context = context.transpose(1, 2).contiguous()
        # look ahead
        if context.size(2) == 0:
            outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
        else:
            assert context.size(2) == self.pre_lookahead_len
            outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
        outputs = F.leaky_relu(self.conv1(outputs))
        # causal conv2 with cached left context
        if conv2_cache.size(2) == 0:
            outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
        else:
            assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1
            outputs = torch.concat([conv2_cache, outputs], dim=2)
        conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):]
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()
        # residual connection
        outputs = outputs + inputs
        return outputs, conv2_cache_new
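
# Usage sketch (illustrative only): the layer is shape-preserving and, in
# non-streaming mode, looks ahead `pre_lookahead_len` frames via zero padding.
# Shapes are example assumptions.
#
#   layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
#   x = torch.randn(1, 40, 512)        # (B, T, C)
#   y, conv2_cache = layer(x)          # y: (1, 40, 512)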


class UpsampleConformerEncoder(torch.nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
- """
- Args:
- input_size (int): input dim
- output_size (int): dimension of attention
- attention_heads (int): the number of heads of multi head attention
- linear_units (int): the hidden units number of position-wise feed
- forward
- num_blocks (int): the number of decoder blocks
- dropout_rate (float): dropout rate
- attention_dropout_rate (float): dropout rate in attention
- positional_dropout_rate (float): dropout rate after adding
- positional encoding
- input_layer (str): input layer type.
- optional [linear, conv2d, conv2d6, conv2d8]
- pos_enc_layer_type (str): Encoder positional encoding layer type.
- opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
- normalize_before (bool):
- True: use layer_norm before each sub-block of a layer.
- False: use layer_norm after each sub-block of a layer.
- static_chunk_size (int): chunk size for static chunk training and
- decoding
- use_dynamic_chunk (bool): whether use dynamic chunk size for
- training or not, You can only use fixed chunk(chunk_size > 0)
- or dyanmic chunk size(use_dynamic_chunk = True)
- global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
- use_dynamic_left_chunk (bool): whether use dynamic left chunk in
- dynamic chunk training
- key_bias: whether use bias in attention.linear_k, False for whisper models.
- gradient_checkpointing: rerunning a forward-pass segment for each
- checkpointed segment during backward.
- """
        super().__init__()
        self._output_size = output_size

        self.global_cmvn = global_cmvn
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)
        self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
        self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
        self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )
        self.up_encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(4)
        ])
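
    # Construction sketch (illustrative only): the values below are assumptions
    # for demonstration, not a config read from this repository. Note that
    # pre_lookahead_layer and up_layer are hard-coded to 512 channels, so
    # output_size should be 512; input_layer='linear' keeps the frame rate
    # unchanged before the 2x upsampling.
    #
    #   encoder = UpsampleConformerEncoder(
    #       input_size=512,
    #       output_size=512,
    #       input_layer='linear',
    #   )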

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
        streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
            streaming: if True, apply the static chunk mask (static_chunk_size)
                so that training matches chunk-wise streaming inference;
                if False, attention sees the full sequence.
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T', D), where T' is the subsampled
                length multiplied by up_layer.stride
            masks: torch.Tensor batch padding mask after subsample and
                upsample (B, 1, T')
        An offline usage sketch follows forward_up_layers() below.
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
        # lookahead + conformer encoder
        xs, _ = self.pre_lookahead_layer(xs)
        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
        # upsample + conformer encoder
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens, _ = self.up_layer(xs, xs_lens)
        xs = xs.transpose(1, 2).contiguous()
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.up_embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                          pos_emb: torch.Tensor,
                          mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.up_encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs
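
    # Offline-usage sketch (illustrative only; tensor sizes are example
    # assumptions, using the encoder from the construction sketch above with
    # input_layer='linear'): forward() runs pre-lookahead + encoders, upsamples
    # the time axis by up_layer.stride (2x), then runs the up_encoders.
    #
    #   xs = torch.randn(2, 100, 512)          # (B, T, D)
    #   xs_lens = torch.tensor([100, 80])
    #   ys, masks = encoder(xs, xs_lens)       # ys: (2, 200, 512)
    #   # masks: (2, 1, 200) padding mask for the upsampled sequence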

    @torch.jit.export
    def forward_chunk(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        offset: int = 0,
        context: torch.Tensor = torch.zeros(0, 0, 0),
        pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0),
        encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0),
        upsample_offset: int = 0,
        upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
        upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
- """Embed positions in tensor.
- Args:
- xs: padded input tensor (B, T, D)
- xs_lens: input length (B)
- decoding_chunk_size: decoding chunk size for dynamic chunk
- 0: default for training, use random dynamic chunk.
- <0: for decoding, use full chunk.
- >0: for decoding, use fixed chunk size as set.
- num_decoding_left_chunks: number of left chunks, this is for decoding,
- the chunk size is decoding_chunk_size.
- >=0: use num_decoding_left_chunks
- <0: use all left chunks
- Returns:
- encoder output tensor xs, and subsampled masks
- xs: padded output tensor (B, T' ~= T/subsample_rate, D)
- masks: torch.Tensor batch padding mask after subsample
- (B, 1, T' ~= T/subsample_rate)
- NOTE(xcsong):
- We pass the `__call__` method of the modules instead of `forward` to the
- checkpointing API because `__call__` attaches all the hooks of the module.
- https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
- """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        offset += xs.size(1)
        tmp_masks = torch.ones(1,
                               context.size(1),
                               device=context.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if context.size(1) != 0:
            context, _, _ = self.embed(context, tmp_masks, offset)
        # lookahead + conformer encoder
        xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
        # NOTE in cache mode we do not need to call add_optional_chunk_mask
        chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
        mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
        encoders_kv_cache_list = []
        for index, layer in enumerate(self.encoders):
            xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
            encoders_kv_cache_list.append(encoders_kv_cache_new)
        encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
        # upsample
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
        xs = xs.transpose(1, 2).contiguous()
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
        upsample_offset += xs.size(1)
        # conformer encoder
        chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
        mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
        upsample_kv_cache_list = []
        for index, layer in enumerate(self.up_encoders):
            xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
            upsample_kv_cache_list.append(upsample_kv_cache_new)
        upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)
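
    # Streaming-usage sketch (illustrative only): forward_chunk() processes one
    # chunk at a time; the lookahead frames of the following chunk go in
    # `context`, and the returned state tuple is threaded into the next call.
    # `prev_state`, `chunk`, `chunk_lens` and `next_context` are hypothetical
    # names; initialising the caches for the very first chunk is not shown
    # here, since the required kv-cache layout is defined by ConformerEncoderLayer.
    #
    #   ys, masks, state = encoder.forward_chunk(
    #       chunk, chunk_lens, context=next_context,
    #       offset=prev_state[0],
    #       pre_lookahead_layer_conv2_cache=prev_state[1],
    #       encoders_kv_cache=prev_state[2],
    #       upsample_offset=prev_state[3],
    #       upsample_conv_cache=prev_state[4],
    #       upsample_kv_cache=prev_state[5])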