encoder.py

# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple

import torch
import torch.utils.checkpoint as ckpt

from cosyvoice.transformer.convolution import ConvolutionModule
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice.utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.mask import add_optional_chunk_mask


class BaseEncoder(torch.nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of encoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether to use dynamic chunk size for
                training or not. You can only use a fixed chunk
                (chunk_size > 0) or a dynamic chunk size
                (use_dynamic_chunk = True)
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
                dynamic chunk training
            key_bias: whether to use bias in attention.linear_k,
                False for whisper models.
            gradient_checkpointing: rerun a forward-pass segment for each
                checkpointed segment during backward.
        """
        super().__init__()
        self._output_size = output_size
        self.global_cmvn = global_cmvn
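        # The input "embedding" both subsamples the feature sequence in time
        # and applies the configured positional encoding to the subsampled
        # frames.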
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
                                                      positional_dropout_rate),
        )
        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
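        # chunk_masks restricts self-attention to the chunk context chosen for
        # this pass, while mask_pad only marks padded frames and is consumed
        # by the conformer convolution module.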
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
                                                  mask_pad)
        else:
            xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
                       pos_emb: torch.Tensor,
                       mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    @torch.jit.unused
    def forward_layers_checkpointed(self, xs: torch.Tensor,
                                    chunk_masks: torch.Tensor,
                                    pos_emb: torch.Tensor,
                                    mask_pad: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
                                                    chunk_masks, pos_emb,
                                                    mask_pad)
        return xs

    @torch.jit.export
    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Forward just one chunk

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
                                               size=attention_key_size)
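        # Decide how much of the (cache + current chunk) key/value history to
        # carry over to the next call: <0 keeps everything, 0 keeps nothing,
        # and a positive value keeps at most the trailing
        # `required_cache_size` frames.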
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
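            # unsqueeze adds a leading layer dimension so the per-layer CNN
            # caches can be concatenated into (elayers, b=1, hidden-dim, cache_t2).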
            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
        if self.normalize_before:
            xs = self.after_norm(xs)
        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
        return (xs, r_att_cache, r_cnn_cache)

    @torch.jit.unused
    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Forward input chunk by chunk with chunk_size like a streaming
            fashion

        Here we should pay special attention to the computation cache in the
        streaming-style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache because:
            1. We can make the subsampling module output the right result by
               overlapping the input instead of caching left context; even
               though it wastes some computation, subsampling only takes a
               very small fraction of the computation in the whole model.
            2. Typically, there are several convolution layers with subsampling
               in the subsampling module, so it is tricky and complicated to
               cache across convolution layers with different subsampling
               rates.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling; we would need to rewrite it to make it
               work with cache, which is not preferred.
        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size
            num_decoding_left_chunks (int): number of left chunks,
                <0 means use all left chunks
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
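        # Each step feeds `decoding_window` raw frames (enough to produce
        # exactly `decoding_chunk_size` subsampled frames, including the
        # subsampling right context) and then advances by `stride` frames, so
        # consecutive windows overlap instead of caching the subsampling left
        # context (see the docstring above).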
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward overlap input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, att_cache,
             cnn_cache) = self.forward_chunk(chunk_xs, offset,
                                             required_cache_size, att_cache,
                                             cnn_cache)
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones((1, 1, ys.size(1)),
                           device=ys.device,
                           dtype=torch.bool)
        return ys, masks


class TransformerEncoder(BaseEncoder):
    """Transformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        key_bias: bool = True,
        selfattention_layer_type: str = "selfattn",
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
    ):
        """ Construct TransformerEncoder

        See Encoder for the meaning of each parameter.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing)
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
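        # Each block is a standard transformer encoder layer: multi-head
        # self-attention followed by a position-wise feed-forward network,
        # with pre- or post-layer-norm selected by `normalize_before`.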
        self.encoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
                                                                      output_size,
                                                                      attention_dropout_rate,
                                                                      key_bias),
                PositionwiseFeedForward(output_size, linear_units,
                                        dropout_rate, activation),
                dropout_rate, normalize_before) for _ in range(num_blocks)
        ])


class ConformerEncoder(BaseEncoder):
    """Conformer encoder module."""

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """Construct ConformerEncoder

        Args:
            input_size to use_dynamic_chunk, see in BaseEncoder
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type.
                The parameter has no effect now; it is kept only for
                configuration compatibility.
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            key_bias: whether to use bias in attention.linear_k,
                False for whisper models.
        """
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk, gradient_checkpointing)
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()

        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args),
                PositionwiseFeedForward(*positionwise_layer_args),
                PositionwiseFeedForward(
                    *positionwise_layer_args) if macaron_style else None,
                ConvolutionModule(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
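

if __name__ == "__main__":
    # Minimal usage sketch (not part of the upstream module): the feature
    # dimension (80), lengths and hyper-parameters below are illustrative
    # assumptions, not values prescribed by CosyVoice.
    encoder = ConformerEncoder(input_size=80,
                               output_size=256,
                               num_blocks=2,
                               static_chunk_size=16)

    xs = torch.randn(2, 200, 80)        # (B, T, D) padded acoustic features
    xs_lens = torch.tensor([200, 160])  # valid length of each utterance

    # Offline (full-utterance) forward pass.
    ys, masks = encoder(xs, xs_lens)
    print(ys.shape, masks.shape)  # roughly (2, T/4, 256) and (2, 1, T/4)

    # Streaming-style forward, one 16-frame (post-subsampling) chunk at a
    # time; forward_chunk_by_chunk expects a single utterance (batch == 1).
    ys_stream, _ = encoder.forward_chunk_by_chunk(xs[:1],
                                                  decoding_chunk_size=16,
                                                  num_decoding_left_chunks=-1)
    print(ys_stream.shape)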