upsample_encoder.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
  3. # 2024 Alibaba Inc (Xiang Lyu)
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # Modified from ESPnet(https://github.com/espnet/espnet)
  17. """Encoder definition."""
  18. from typing import Tuple
  19. import torch
  20. from torch import nn
  21. import torch.utils.checkpoint as ckpt
  22. from torch.nn import functional as F
  23. from cosyvoice.transformer.convolution import ConvolutionModule
  24. from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
  25. from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
  26. from cosyvoice.utils.class_utils import (
  27. COSYVOICE_EMB_CLASSES,
  28. COSYVOICE_SUBSAMPLE_CLASSES,
  29. COSYVOICE_ATTENTION_CLASSES,
  30. COSYVOICE_ACTIVATION_CLASSES,
  31. )
  32. from cosyvoice.utils.mask import make_pad_mask
  33. from cosyvoice.utils.mask import add_optional_chunk_mask
  34. class Upsample1D(nn.Module):
  35. """A 1D upsampling layer with an optional convolution.
  36. Parameters:
  37. channels (`int`):
  38. number of channels in the inputs and outputs.
  39. use_conv (`bool`, default `False`):
  40. option to use a convolution.
  41. use_conv_transpose (`bool`, default `False`):
  42. option to use a convolution transpose.
  43. out_channels (`int`, optional):
  44. number of output channels. Defaults to `channels`.
  45. """
  46. def __init__(self, channels: int, out_channels: int, stride: int=2):
  47. super().__init__()
  48. self.channels = channels
  49. self.out_channels = out_channels
  50. self.stride = stride
  51. # In this mode, first repeat interpolate, than conv with stride=1
  52. self.conv = nn.Conv1d(
  53. self.channels, self.out_channels, stride*2+1, stride=1,
  54. padding=0,
  55. )
  56. def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
  57. outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
  58. outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
  59. outputs = self.conv(outputs)
  60. return outputs, input_lengths * self.stride
  61. class PreLookaheadLayer(nn.Module):
  62. def __init__(self, channels: int, pre_lookahead_len: int = 1):
  63. super().__init__()
  64. self.channels = channels
  65. self.pre_lookahead_len = pre_lookahead_len
  66. self.conv1 = nn.Conv1d(
  67. channels, channels,
  68. kernel_size=pre_lookahead_len+1,
  69. stride=1, padding=0,
  70. )
  71. self.conv2 = nn.Conv1d(
  72. channels, channels,
  73. kernel_size=3, stride=1, padding=0,
  74. )
  75. def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  76. """
  77. inputs: (batch_size, seq_len, channels)
  78. """
  79. outputs = inputs.transpose(1, 2).contiguous()
  80. # look ahead
  81. outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
  82. outputs = F.leaky_relu(self.conv1(outputs))
  83. # outputs
  84. outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
  85. outputs = self.conv2(outputs)
  86. outputs = outputs.transpose(1, 2).contiguous()
  87. # residual connection
  88. outputs = outputs + inputs
  89. return outputs
  90. class UpsampleConformerEncoder(torch.nn.Module):
  91. def __init__(
  92. self,
  93. input_size: int,
  94. output_size: int = 256,
  95. attention_heads: int = 4,
  96. linear_units: int = 2048,
  97. num_blocks: int = 6,
  98. dropout_rate: float = 0.1,
  99. positional_dropout_rate: float = 0.1,
  100. attention_dropout_rate: float = 0.0,
  101. input_layer: str = "conv2d",
  102. pos_enc_layer_type: str = "rel_pos",
  103. normalize_before: bool = True,
  104. static_chunk_size: int = 0,
  105. use_dynamic_chunk: bool = False,
  106. global_cmvn: torch.nn.Module = None,
  107. use_dynamic_left_chunk: bool = False,
  108. positionwise_conv_kernel_size: int = 1,
  109. macaron_style: bool = True,
  110. selfattention_layer_type: str = "rel_selfattn",
  111. activation_type: str = "swish",
  112. use_cnn_module: bool = True,
  113. cnn_module_kernel: int = 15,
  114. causal: bool = False,
  115. cnn_module_norm: str = "batch_norm",
  116. key_bias: bool = True,
  117. gradient_checkpointing: bool = False,
  118. ):
  119. """
  120. Args:
  121. input_size (int): input dim
  122. output_size (int): dimension of attention
  123. attention_heads (int): the number of heads of multi head attention
  124. linear_units (int): the hidden units number of position-wise feed
  125. forward
  126. num_blocks (int): the number of decoder blocks
  127. dropout_rate (float): dropout rate
  128. attention_dropout_rate (float): dropout rate in attention
  129. positional_dropout_rate (float): dropout rate after adding
  130. positional encoding
  131. input_layer (str): input layer type.
  132. optional [linear, conv2d, conv2d6, conv2d8]
  133. pos_enc_layer_type (str): Encoder positional encoding layer type.
  134. opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
  135. normalize_before (bool):
  136. True: use layer_norm before each sub-block of a layer.
  137. False: use layer_norm after each sub-block of a layer.
  138. static_chunk_size (int): chunk size for static chunk training and
  139. decoding
  140. use_dynamic_chunk (bool): whether use dynamic chunk size for
  141. training or not, You can only use fixed chunk(chunk_size > 0)
  142. or dyanmic chunk size(use_dynamic_chunk = True)
  143. global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
  144. use_dynamic_left_chunk (bool): whether use dynamic left chunk in
  145. dynamic chunk training
  146. key_bias: whether use bias in attention.linear_k, False for whisper models.
  147. gradient_checkpointing: rerunning a forward-pass segment for each
  148. checkpointed segment during backward.
  149. """
  150. super().__init__()
  151. self._output_size = output_size
  152. self.global_cmvn = global_cmvn
  153. self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
  154. input_size,
  155. output_size,
  156. dropout_rate,
  157. COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
  158. positional_dropout_rate),
  159. )
  160. self.normalize_before = normalize_before
  161. self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
  162. self.static_chunk_size = static_chunk_size
  163. self.use_dynamic_chunk = use_dynamic_chunk
  164. self.use_dynamic_left_chunk = use_dynamic_left_chunk
  165. self.gradient_checkpointing = gradient_checkpointing
  166. activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
  167. # self-attention module definition
  168. encoder_selfattn_layer_args = (
  169. attention_heads,
  170. output_size,
  171. attention_dropout_rate,
  172. key_bias,
  173. )
  174. # feed-forward module definition
  175. positionwise_layer_args = (
  176. output_size,
  177. linear_units,
  178. dropout_rate,
  179. activation,
  180. )
  181. # convolution module definition
  182. convolution_layer_args = (output_size, cnn_module_kernel, activation,
  183. cnn_module_norm, causal)
  184. self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
  185. self.encoders = torch.nn.ModuleList([
  186. ConformerEncoderLayer(
  187. output_size,
  188. COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
  189. *encoder_selfattn_layer_args),
  190. PositionwiseFeedForward(*positionwise_layer_args),
  191. PositionwiseFeedForward(
  192. *positionwise_layer_args) if macaron_style else None,
  193. ConvolutionModule(
  194. *convolution_layer_args) if use_cnn_module else None,
  195. dropout_rate,
  196. normalize_before,
  197. ) for _ in range(num_blocks)
  198. ])
  199. self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
  200. self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
  201. input_size,
  202. output_size,
  203. dropout_rate,
  204. COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
  205. positional_dropout_rate),
  206. )
  207. self.up_encoders = torch.nn.ModuleList([
  208. ConformerEncoderLayer(
  209. output_size,
  210. COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
  211. *encoder_selfattn_layer_args),
  212. PositionwiseFeedForward(*positionwise_layer_args),
  213. PositionwiseFeedForward(
  214. *positionwise_layer_args) if macaron_style else None,
  215. ConvolutionModule(
  216. *convolution_layer_args) if use_cnn_module else None,
  217. dropout_rate,
  218. normalize_before,
  219. ) for _ in range(4)
  220. ])
  221. def output_size(self) -> int:
  222. return self._output_size
  223. def forward(
  224. self,
  225. xs: torch.Tensor,
  226. xs_lens: torch.Tensor,
  227. decoding_chunk_size: int = 0,
  228. num_decoding_left_chunks: int = -1,
  229. ) -> Tuple[torch.Tensor, torch.Tensor]:
  230. """Embed positions in tensor.
  231. Args:
  232. xs: padded input tensor (B, T, D)
  233. xs_lens: input length (B)
  234. decoding_chunk_size: decoding chunk size for dynamic chunk
  235. 0: default for training, use random dynamic chunk.
  236. <0: for decoding, use full chunk.
  237. >0: for decoding, use fixed chunk size as set.
  238. num_decoding_left_chunks: number of left chunks, this is for decoding,
  239. the chunk size is decoding_chunk_size.
  240. >=0: use num_decoding_left_chunks
  241. <0: use all left chunks
  242. Returns:
  243. encoder output tensor xs, and subsampled masks
  244. xs: padded output tensor (B, T' ~= T/subsample_rate, D)
  245. masks: torch.Tensor batch padding mask after subsample
  246. (B, 1, T' ~= T/subsample_rate)
  247. NOTE(xcsong):
  248. We pass the `__call__` method of the modules instead of `forward` to the
  249. checkpointing API because `__call__` attaches all the hooks of the module.
  250. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
  251. """
  252. T = xs.size(1)
  253. masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
  254. if self.global_cmvn is not None:
  255. xs = self.global_cmvn(xs)
  256. xs, pos_emb, masks = self.embed(xs, masks)
  257. mask_pad = masks # (B, 1, T/subsample_rate)
  258. chunk_masks = add_optional_chunk_mask(xs, masks,
  259. self.use_dynamic_chunk,
  260. self.use_dynamic_left_chunk,
  261. decoding_chunk_size,
  262. self.static_chunk_size,
  263. num_decoding_left_chunks)
  264. # lookahead + conformer encoder
  265. xs = self.pre_lookahead_layer(xs)
  266. xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
  267. # upsample + conformer encoder
  268. xs = xs.transpose(1, 2).contiguous()
  269. xs, xs_lens = self.up_layer(xs, xs_lens)
  270. xs = xs.transpose(1, 2).contiguous()
  271. T = xs.size(1)
  272. masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
  273. xs, pos_emb, masks = self.up_embed(xs, masks)
  274. mask_pad = masks # (B, 1, T/subsample_rate)
  275. chunk_masks = add_optional_chunk_mask(xs, masks,
  276. self.use_dynamic_chunk,
  277. self.use_dynamic_left_chunk,
  278. decoding_chunk_size,
  279. self.static_chunk_size * self.up_layer.stride,
  280. num_decoding_left_chunks)
  281. xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
  282. if self.normalize_before:
  283. xs = self.after_norm(xs)
  284. # Here we assume the mask is not changed in encoder layers, so just
  285. # return the masks before encoder layers, and the masks will be used
  286. # for cross attention with decoder later
  287. return xs, masks
  288. def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
  289. pos_emb: torch.Tensor,
  290. mask_pad: torch.Tensor) -> torch.Tensor:
  291. for layer in self.encoders:
  292. xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
  293. return xs
  294. def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
  295. pos_emb: torch.Tensor,
  296. mask_pad: torch.Tensor) -> torch.Tensor:
  297. for layer in self.up_encoders:
  298. xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
  299. return xs