upsample_encoder.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
  3. # 2024 Alibaba Inc (Xiang Lyu)
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # Modified from ESPnet(https://github.com/espnet/espnet)
  17. """Encoder definition."""
  18. from typing import Tuple
  19. import torch
  20. from torch import nn
  21. from torch.nn import functional as F
  22. from cosyvoice.transformer.convolution import ConvolutionModule
  23. from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
  24. from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
  25. from cosyvoice.utils.class_utils import (
  26. COSYVOICE_EMB_CLASSES,
  27. COSYVOICE_SUBSAMPLE_CLASSES,
  28. COSYVOICE_ATTENTION_CLASSES,
  29. COSYVOICE_ACTIVATION_CLASSES,
  30. )
  31. from cosyvoice.utils.mask import make_pad_mask
  32. from cosyvoice.utils.mask import add_optional_chunk_mask
  class Upsample1D(nn.Module):
      """Causal 1D upsampler: nearest-neighbour repeat, then a stride-1 convolution.

      The time axis is stretched by ``stride`` through interpolation and the
      result is passed through an unpadded convolution whose kernel covers
      ``stride * 2 + 1`` frames. ``forward`` prepends ``stride * 2`` frames of
      left context (zeros, or the cache of a previous call), so the output is
      exactly ``stride`` times as long as the input.

      Parameters:
          channels (`int`):
              number of channels in the inputs.
          out_channels (`int`):
              number of channels in the outputs.
          stride (`int`, default `2`):
              upsampling factor applied to the time axis.
      """

      def __init__(self, channels: int, out_channels: int, stride: int = 2):
          super().__init__()
          self.channels = channels
          self.out_channels = out_channels
          self.stride = stride
          # The length change comes from the interpolation in forward(); the
          # convolution itself runs at stride 1 over the stretched sequence.
          kernel_size = stride * 2 + 1
          self.conv = nn.Conv1d(channels, out_channels, kernel_size, stride=1, padding=0)

      def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """Upsample ``inputs`` along time and return the cache for the next call.

          Args:
              inputs: tensor laid out as (batch, channels, time).
              input_lengths: valid lengths of ``inputs``; returned scaled by
                  ``stride``.
              conv_cache: left context from a previous call, i.e. its third
                  return value. An empty time axis (the default) means "start
                  of sequence" and zero padding is used instead.

          Returns:
              Tuple of (upsampled output, ``input_lengths * stride``, the last
              ``stride * 2`` frames of the convolution input).
          """
          context_len = self.stride * 2
          stretched = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
          if conv_cache.size(2) != 0:
              # Streaming continuation: reuse real history as left context.
              assert conv_cache.size(2) == context_len
              conv_input = torch.concat([conv_cache, stretched], dim=2)
          else:
              # First (or only) call: left context is all zeros.
              conv_input = F.pad(stretched, (context_len, 0), value=0.0)
          conv_cache_new = conv_input[:, :, -context_len:]
          return self.conv(conv_input), input_lengths * self.stride, conv_cache_new
  class PreLookaheadLayer(nn.Module):
      """Residual block of two 1D convolutions: one looks ahead, one looks back.

      ``conv1`` has kernel size ``pre_lookahead_len + 1`` and is fed
      ``pre_lookahead_len`` extra frames on the right (zeros or ``context``).
      ``conv2`` has kernel size 3 and is fed two extra frames on the left
      (zeros or ``conv2_cache``). Both are unpadded, so the sequence length is
      preserved and the result can be added back to the input.
      """

      def __init__(self, channels: int, pre_lookahead_len: int = 1):
          super().__init__()
          self.channels = channels
          self.pre_lookahead_len = pre_lookahead_len
          self.conv1 = nn.Conv1d(channels, channels, kernel_size=pre_lookahead_len + 1, stride=1, padding=0)
          self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, stride=1, padding=0)

      def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
          """Apply the look-ahead block.

          Args:
              inputs: (batch_size, seq_len, channels)
              context: future frames following ``inputs``, laid out like
                  ``inputs``. An empty tensor (the default) means no future
                  frames are known and zeros are used.
              conv2_cache: left context for ``conv2`` from a previous call
                  (its second return value), laid out as
                  (batch_size, channels, kernel_size - 1). Empty means zeros.

          Returns:
              Tuple of (``inputs`` plus the convolved signal, the trailing
              ``kernel_size - 1`` frames of the ``conv2`` input).
          """
          hidden = inputs.transpose(1, 2).contiguous()
          future = context.transpose(1, 2).contiguous()

          # look ahead: extend the right edge with known future frames and/or zeros
          right_pad = self.pre_lookahead_len
          if future.size(2) != 0:
              assert future.size(2) == self.pre_lookahead_len
              hidden = torch.concat([hidden, future], dim=2)
              # Zero while the assertion above holds.
              right_pad = self.pre_lookahead_len - future.size(2)
          hidden = F.pad(hidden, (0, right_pad), mode='constant', value=0.0)
          hidden = F.leaky_relu(self.conv1(hidden))

          # look back: extend the left edge with cached history or zeros
          history_len = self.conv2.kernel_size[0] - 1
          if conv2_cache.size(2) != 0:
              assert conv2_cache.size(2) == history_len
              hidden = torch.concat([conv2_cache, hidden], dim=2)
          else:
              hidden = F.pad(hidden, (history_len, 0), mode='constant', value=0.0)
          conv2_cache_new = hidden[:, :, -history_len:]
          hidden = self.conv2(hidden).transpose(1, 2).contiguous()

          # residual connection
          return hidden + inputs, conv2_cache_new
  class UpsampleConformerEncoder(torch.nn.Module):
      """Conformer encoder with a temporal upsampling step between two layer stacks.

      Data flow, as implemented in ``forward``:
      ``embed`` -> ``pre_lookahead_layer`` -> ``encoders`` -> ``up_layer``
      -> ``up_embed`` -> ``up_encoders`` -> ``after_norm`` (only when
      ``normalize_before`` is set).

      ``forward_chunk`` runs the same pipeline on one chunk at a time while
      carrying convolution and attention caches between calls.
      """

      def __init__(
          self,
          input_size: int,
          output_size: int = 256,
          attention_heads: int = 4,
          linear_units: int = 2048,
          num_blocks: int = 6,
          dropout_rate: float = 0.1,
          positional_dropout_rate: float = 0.1,
          attention_dropout_rate: float = 0.0,
          input_layer: str = "conv2d",
          pos_enc_layer_type: str = "rel_pos",
          normalize_before: bool = True,
          static_chunk_size: int = 0,
          use_dynamic_chunk: bool = False,
          global_cmvn: torch.nn.Module = None,
          use_dynamic_left_chunk: bool = False,
          positionwise_conv_kernel_size: int = 1,
          macaron_style: bool = True,
          selfattention_layer_type: str = "rel_selfattn",
          activation_type: str = "swish",
          use_cnn_module: bool = True,
          cnn_module_kernel: int = 15,
          causal: bool = False,
          cnn_module_norm: str = "batch_norm",
          key_bias: bool = True,
          gradient_checkpointing: bool = False,
      ):
          """
          Args:
              input_size (int): input dim
              output_size (int): dimension of attention
              attention_heads (int): the number of heads of multi head attention
              linear_units (int): the hidden units number of position-wise feed
                  forward
              num_blocks (int): the number of conformer layers in the first
                  (pre-upsampling) stack; the post-upsampling stack always has 4
              dropout_rate (float): dropout rate
              attention_dropout_rate (float): dropout rate in attention
              positional_dropout_rate (float): dropout rate after adding
                  positional encoding
              input_layer (str): input layer type, a key of
                  COSYVOICE_SUBSAMPLE_CLASSES. It is used for both ``embed`` and
                  ``up_embed``.
              pos_enc_layer_type (str): Encoder positional encoding layer type,
                  a key of COSYVOICE_EMB_CLASSES.
              normalize_before (bool):
                  True: use layer_norm before each sub-block of a layer.
                  False: use layer_norm after each sub-block of a layer.
              static_chunk_size (int): chunk size for static chunk training and
                  decoding
              use_dynamic_chunk (bool): whether use dynamic chunk size for
                  training or not, You can only use fixed chunk(chunk_size > 0)
                  or dynamic chunk size(use_dynamic_chunk = True)
              global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
              use_dynamic_left_chunk (bool): whether use dynamic left chunk in
                  dynamic chunk training
              positionwise_conv_kernel_size (int): accepted for interface
                  compatibility; not referenced in this class.
              macaron_style (bool): whether to add a second feed-forward module
                  to every conformer layer.
              selfattention_layer_type (str): key of COSYVOICE_ATTENTION_CLASSES.
              activation_type (str): key of COSYVOICE_ACTIVATION_CLASSES.
              use_cnn_module (bool): whether to add a ConvolutionModule to every
                  conformer layer.
              cnn_module_kernel (int): forwarded to ConvolutionModule.
              causal (bool): forwarded to ConvolutionModule.
              cnn_module_norm (str): forwarded to ConvolutionModule.
              key_bias: whether use bias in attention.linear_k, False for whisper models.
              gradient_checkpointing: rerunning a forward-pass segment for each
                  checkpointed segment during backward. Only stored on the
                  instance here.
          """
          super().__init__()
          self._output_size = output_size

          self.global_cmvn = global_cmvn
          self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
              input_size,
              output_size,
              dropout_rate,
              COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                        positional_dropout_rate),
          )

          self.normalize_before = normalize_before
          self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
          self.static_chunk_size = static_chunk_size
          self.use_dynamic_chunk = use_dynamic_chunk
          self.use_dynamic_left_chunk = use_dynamic_left_chunk
          self.gradient_checkpointing = gradient_checkpointing
          activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
          # self-attention module definition
          encoder_selfattn_layer_args = (
              attention_heads,
              output_size,
              attention_dropout_rate,
              key_bias,
          )
          # feed-forward module definition
          positionwise_layer_args = (
              output_size,
              linear_units,
              dropout_rate,
              activation,
          )
          # convolution module definition
          convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                    cnn_module_norm, causal)
          # The look-ahead layer consumes the output of ``embed``, so its channel
          # count must follow ``output_size`` (previously hard-coded to 512).
          self.pre_lookahead_layer = PreLookaheadLayer(channels=output_size, pre_lookahead_len=3)
          self.encoders = torch.nn.ModuleList([
              ConformerEncoderLayer(
                  output_size,
                  COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                      *encoder_selfattn_layer_args),
                  PositionwiseFeedForward(*positionwise_layer_args),
                  PositionwiseFeedForward(
                      *positionwise_layer_args) if macaron_style else None,
                  ConvolutionModule(
                      *convolution_layer_args) if use_cnn_module else None,
                  dropout_rate,
                  normalize_before,
              ) for _ in range(num_blocks)
          ])
          # The upsampler consumes the output of ``encoders``, so it also
          # follows ``output_size`` (previously hard-coded to 512).
          # NOTE(review): ``up_embed`` below is still built with ``input_size``
          # as its input dim while it is fed the upsampler output; presumably
          # configurations use input_size == output_size -- confirm.
          self.up_layer = Upsample1D(channels=output_size, out_channels=output_size, stride=2)
          self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
              input_size,
              output_size,
              dropout_rate,
              COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                        positional_dropout_rate),
          )
          self.up_encoders = torch.nn.ModuleList([
              ConformerEncoderLayer(
                  output_size,
                  COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                      *encoder_selfattn_layer_args),
                  PositionwiseFeedForward(*positionwise_layer_args),
                  PositionwiseFeedForward(
                      *positionwise_layer_args) if macaron_style else None,
                  ConvolutionModule(
                      *convolution_layer_args) if use_cnn_module else None,
                  dropout_rate,
                  normalize_before,
              ) for _ in range(4)
          ])

      def output_size(self) -> int:
          """Return the feature dimension passed as ``output_size`` at construction."""
          return self._output_size

      def forward(
          self,
          xs: torch.Tensor,
          xs_lens: torch.Tensor,
          decoding_chunk_size: int = 0,
          num_decoding_left_chunks: int = -1,
          streaming: bool = False,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Encode a padded batch and upsample it along time.

          Args:
              xs: padded input tensor (B, T, D)
              xs_lens: input length (B)
              decoding_chunk_size: decoding chunk size for dynamic chunk
                  0: default for training, use random dynamic chunk.
                  <0: for decoding, use full chunk.
                  >0: for decoding, use fixed chunk size as set.
              num_decoding_left_chunks: number of left chunks, this is for decoding,
                  the chunk size is decoding_chunk_size.
                  >=0: use num_decoding_left_chunks
                  <0: use all left chunks
              streaming: only when this is exactly ``True`` are the chunk
                  settings (``use_dynamic_chunk``, ``use_dynamic_left_chunk``,
                  ``decoding_chunk_size``, ``static_chunk_size``,
                  ``num_decoding_left_chunks``) forwarded to
                  ``add_optional_chunk_mask``; otherwise neutral values
                  (False / 0 / -1) are passed.
          Returns:
              encoder output tensor xs, and masks
              xs: padded output tensor after the upsampling stage
              masks: torch.Tensor batch padding mask returned by ``up_embed``
          """
          T = xs.size(1)
          masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
          if self.global_cmvn is not None:
              xs = self.global_cmvn(xs)
          xs, pos_emb, masks = self.embed(xs, masks)
          mask_pad = masks  # (B, 1, T/subsample_rate)
          chunk_masks = add_optional_chunk_mask(
              xs, masks,
              self.use_dynamic_chunk if streaming is True else False,
              self.use_dynamic_left_chunk if streaming is True else False,
              decoding_chunk_size if streaming is True else 0,
              self.static_chunk_size if streaming is True else 0,
              num_decoding_left_chunks if streaming is True else -1)
          # lookahead + conformer encoder
          xs, _ = self.pre_lookahead_layer(xs)
          xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)

          # upsample + conformer encoder
          xs = xs.transpose(1, 2).contiguous()
          xs, xs_lens, _ = self.up_layer(xs, xs_lens)
          xs = xs.transpose(1, 2).contiguous()
          T = xs.size(1)
          masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
          xs, pos_emb, masks = self.up_embed(xs, masks)
          mask_pad = masks  # (B, 1, T/subsample_rate)
          # The static chunk size is scaled by the upsampling stride so a chunk
          # covers the same span of input frames before and after upsampling.
          chunk_masks = add_optional_chunk_mask(
              xs, masks,
              self.use_dynamic_chunk if streaming is True else False,
              self.use_dynamic_left_chunk if streaming is True else False,
              decoding_chunk_size if streaming is True else 0,
              self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
              num_decoding_left_chunks if streaming is True else -1)
          xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)

          if self.normalize_before:
              xs = self.after_norm(xs)
          # Here we assume the mask is not changed in encoder layers, so just
          # return the masks before encoder layers, and the masks will be used
          # for cross attention with decoder later
          return xs, masks

      def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                         pos_emb: torch.Tensor,
                         mask_pad: torch.Tensor) -> torch.Tensor:
          """Run ``xs`` through every layer of the pre-upsampling stack."""
          for layer in self.encoders:
              xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
          return xs

      def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                            pos_emb: torch.Tensor,
                            mask_pad: torch.Tensor) -> torch.Tensor:
          """Run ``xs`` through every layer of the post-upsampling stack."""
          for layer in self.up_encoders:
              xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
          return xs

      @torch.jit.export
      def forward_chunk(
          self,
          xs: torch.Tensor,
          xs_lens: torch.Tensor,
          offset: int = 0,
          context: torch.Tensor = torch.zeros(0, 0, 0),
          pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0),
          encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0),
          upsample_offset: int = 0,
          upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
          upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
      ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
          """Encode one chunk, carrying state between successive calls.

          Args:
              xs: chunk input tensor (1, T, D); batch size must be 1.
              xs_lens: input length, passed on to the upsampling layer.
              offset: value passed to ``embed`` as its offset argument
                  (presumably the number of frames already processed --
                  confirm against the embed class). Advanced by the embedded
                  chunk length.
              context: look-ahead frames following ``xs``, laid out like
                  ``xs``. Empty (the default) means none are available.
              pre_lookahead_layer_conv2_cache: cache returned by the
                  look-ahead layer on the previous call; empty on first call.
              encoders_kv_cache: per-layer attention caches of ``encoders``,
                  stacked along dim 0. An empty dim 0 (the default) means
                  "no history" and the layers are called without a cache.
              upsample_offset: like ``offset`` but for ``up_embed``.
              upsample_conv_cache: cache returned by the upsampling layer on
                  the previous call; empty on first call.
              upsample_kv_cache: like ``encoders_kv_cache`` but for
                  ``up_encoders``.
          Returns:
              xs: encoder output for this chunk.
              masks: mask returned by ``up_embed`` for an all-ones input mask.
              A tuple ``(offset, pre_lookahead_layer_conv2_cache,
              encoders_kv_cache, upsample_offset, upsample_conv_cache,
              upsample_kv_cache)`` holding the updated state to pass to the
              next call.
          """
          assert xs.size(0) == 1
          # tmp_masks is just for interface compatibility
          tmp_masks = torch.ones(1,
                                 xs.size(1),
                                 device=xs.device,
                                 dtype=torch.bool)
          tmp_masks = tmp_masks.unsqueeze(1)
          if self.global_cmvn is not None:
              xs = self.global_cmvn(xs)
          # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
          xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
          offset += xs.size(1)
          tmp_masks = torch.ones(1,
                                 context.size(1),
                                 device=context.device,
                                 dtype=torch.bool)
          tmp_masks = tmp_masks.unsqueeze(1)
          if context.size(1) != 0:
              # The context follows the chunk, so it is embedded at the
              # already-advanced offset.
              context, _, _ = self.embed(context, tmp_masks, offset)

          # lookahead + conformer encoder
          xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
          # NOTE in cache mode we do not need to call add_optional_chunk_mask
          chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
          mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
          # Indexing an empty default cache would raise IndexError, so fall
          # back to a cache-less layer call exactly as forward_layers does.
          has_encoders_kv_cache = encoders_kv_cache.size(0) > 0
          encoders_kv_cache_list = []
          for index, layer in enumerate(self.encoders):
              if has_encoders_kv_cache:
                  xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
              else:
                  xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
              encoders_kv_cache_list.append(encoders_kv_cache_new)
          encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)

          # upsample
          xs = xs.transpose(1, 2).contiguous()
          xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
          xs = xs.transpose(1, 2).contiguous()

          # tmp_masks is just for interface compatibility
          tmp_masks = torch.ones(1,
                                 xs.size(1),
                                 device=xs.device,
                                 dtype=torch.bool)
          tmp_masks = tmp_masks.unsqueeze(1)
          xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
          upsample_offset += xs.size(1)

          # conformer encoder
          chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
          mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
          # Same empty-cache handling as for the first stack.
          has_upsample_kv_cache = upsample_kv_cache.size(0) > 0
          upsample_kv_cache_list = []
          for index, layer in enumerate(self.up_encoders):
              if has_upsample_kv_cache:
                  xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
              else:
                  xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
              upsample_kv_cache_list.append(upsample_kv_cache_new)
          upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)

          if self.normalize_before:
              xs = self.after_norm(xs)
          # Here we assume the mask is not changed in encoder layers, so just
          # return the masks before encoder layers, and the masks will be used
          # for cross attention with decoder later
          return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)