encoder.py

# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple

import torch
import torch.utils.checkpoint as ckpt

from cosyvoice.transformer.convolution import ConvolutionModule
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice.utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.mask import add_optional_chunk_mask


class BaseEncoder(torch.nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of encoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether to use dynamic chunk size for
                training or not. You can only use a fixed chunk
                (static_chunk_size > 0) or a dynamic chunk size
                (use_dynamic_chunk = True).
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
                dynamic chunk training
            gradient_checkpointing (bool): rerun a forward-pass segment for
                each checkpointed segment during the backward pass.
        """
        super().__init__()
        self._output_size = output_size

        self.global_cmvn = global_cmvn
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
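        # Illustrative note (assuming make_pad_mask follows the usual
        # WeNet/ESPnet convention of returning True at padded positions):
        # for xs_lens = [4, 2] and T = 4, make_pad_mask yields
        #   [[False, False, False, False],
        #    [False, False, True,  True ]]
        # so after inversion `masks` is True at valid frames and False at
        # padding, with shape (B=2, 1, T=4).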
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
                                                  mask_pad)
        else:
            xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    @torch.jit.ignore(drop=True)
    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    chunk_masks: torch.Tensor,
                                    pos_emb: torch.Tensor,
                                    mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
                                                    chunk_masks, pos_emb,
                                                    mask_pad)
        return xs

    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Forward just one chunk

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (int): current offset in encoder output timestamps
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
                                               size=attention_key_size)
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)
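        # Worked example with illustrative numbers: if required_cache_size is
        # 64, the incoming att_cache holds cache_t1 = 64 key/value frames and
        # the current chunk adds chunk_size = 16 frames, then
        # attention_key_size = 64 + 16 = 80 and
        # next_cache_start = max(80 - 64, 0) = 16, so the cache returned below
        # keeps only the most recent 64 frames. required_cache_size < 0 keeps
        # the whole history, and required_cache_size == 0 keeps nothing.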
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
        if self.normalize_before:
            xs = self.after_norm(xs)

        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

        return (xs, r_att_cache, r_cnn_cache)

    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Forward input chunk by chunk with chunk_size in a streaming
            fashion

        Here we should pay special attention to computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache because:
            1. We can control the subsampling module to output the right
               result by overlapping the input instead of caching left
               context. Even though this wastes some computation, subsampling
               only takes a very small fraction of the computation in the
               whole model.
            2. Typically, there are several convolution layers with
               subsampling in the subsampling module, so it is tricky and
               complicated to cache across convolution layers with different
               subsampling rates.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling; we would need to rewrite it to make it
               work with cache, which is not preferred.
        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size
        """
        assert decoding_chunk_size > 0
        # The model is trained with static or dynamic chunks
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
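        # Illustrative arithmetic (assuming the default conv2d subsampling,
        # which in this WeNet-style code has subsampling_rate = 4 and
        # right_context = 6): with decoding_chunk_size = 16,
        #   context         = 6 + 1 = 7
        #   stride          = 4 * 16 = 64 input frames per step
        #   decoding_window = (16 - 1) * 4 + 7 = 67 input frames per chunk
        # so consecutive chunks re-read 3 input frames, which gives the
        # subsampling convolution its context without a dedicated cache.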
        num_frames = xs.size(1)
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward overlapping input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, att_cache,
             cnn_cache) = self.forward_chunk(chunk_xs, offset,
                                             required_cache_size, att_cache,
                                             cnn_cache)
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones((1, 1, ys.size(1)),
                           device=ys.device,
                           dtype=torch.bool)
        return ys, masks


class TransformerEncoder(BaseEncoder):
    """Transformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        key_bias: bool = True,
        selfattention_layer_type: str = "selfattn",
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
    ):
        """ Construct TransformerEncoder

        See BaseEncoder for the meaning of each parameter.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing)
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        self.encoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    attention_heads,
                    output_size,
                    attention_dropout_rate,
                    key_bias),
                PositionwiseFeedForward(output_size, linear_units,
                                        dropout_rate, activation),
                dropout_rate, normalize_before) for _ in range(num_blocks)
        ])


class ConformerEncoder(BaseEncoder):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """Construct ConformerEncoder

        Args:
            input_size to use_dynamic_chunk, see in BaseEncoder
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now, it's just for configuration
                compatibility.
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            key_bias (bool): whether to use bias in attention.linear_k,
                False for whisper models.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing)
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()

        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
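

# --- Minimal usage sketch (illustrative only, not part of the original API) --
# This block is an added example, not original code: it assumes 80-dim
# fbank-like input features, the default "conv2d" (4x) subsampling, and an
# importable cosyvoice package. Shapes follow the docstrings above.
if __name__ == "__main__":
    torch.manual_seed(0)
    encoder = ConformerEncoder(
        input_size=80,          # mel-dim of the input features (assumed)
        output_size=256,
        static_chunk_size=16,   # fixed 16-frame chunks for chunked attention
        causal=True,            # causal conv so the cnn cache is usable
    )
    encoder.eval()

    xs = torch.randn(1, 200, 80)        # (B=1, T=200 input frames, D=80)
    xs_lens = torch.tensor([200])

    with torch.no_grad():
        # Offline (full-utterance) forward with chunked attention masks.
        ys_full, _ = encoder(xs, xs_lens)
        # Streaming forward: 16 output frames per chunk, full left context.
        ys_stream, _ = encoder.forward_chunk_by_chunk(
            xs, decoding_chunk_size=16, num_decoding_left_chunks=-1)

    # Both paths see the same context per frame, so the two outputs are
    # expected to be close; shapes are (1, T', 256) with T' ~= 200 / 4.
    print(ys_full.shape, ys_stream.shape)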