# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
  16. """Decoder definition."""
  17. from typing import Tuple, List, Optional
  18. import torch
  19. import torch.utils.checkpoint as ckpt
  20. import logging
  21. from cosyvoice.transformer.decoder_layer import DecoderLayer
  22. from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
  23. from cosyvoice.utils.class_utils import (
  24. COSYVOICE_EMB_CLASSES,
  25. COSYVOICE_ATTENTION_CLASSES,
  26. COSYVOICE_ACTIVATION_CLASSES,
  27. )
  28. from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)


class TransformerDecoder(torch.nn.Module):
    """Base class of Transformer decoder module.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        src_attention: if False, encoder-decoder cross attention is not
            applied, e.g. in a CIF model
        key_bias: whether to use bias in attention.linear_k, False for whisper models.
        gradient_checkpointing: rerun each checkpointed forward-pass segment
            during the backward pass to trade compute for memory.
        tie_word_embedding: tie or clone the module weights depending on whether
            we are using TorchScript or not
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        src_attention: bool = True,
        key_bias: bool = True,
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
        tie_word_embedding: bool = False,
    ):
        super().__init__()
        attention_dim = encoder_output_size
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()

        self.embed = torch.nn.Sequential(
            torch.nn.Identity() if input_layer == "no_pos" else
            torch.nn.Embedding(vocab_size, attention_dim),
            COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
                                               positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
        self.use_output_layer = use_output_layer
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = torch.nn.Identity()
        self.num_blocks = num_blocks
        self.decoders = torch.nn.ModuleList([
            DecoderLayer(
                attention_dim,
                COSYVOICE_ATTENTION_CLASSES["selfattn"](
                    attention_heads, attention_dim,
                    self_attention_dropout_rate, key_bias),
                COSYVOICE_ATTENTION_CLASSES["selfattn"](
                    attention_heads, attention_dim, src_attention_dropout_rate,
                    key_bias) if src_attention else None,
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate, activation),
                dropout_rate,
                normalize_before,
            ) for _ in range(self.num_blocks)
        ])
        self.gradient_checkpointing = gradient_checkpointing
        self.tie_word_embedding = tie_word_embedding

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor = torch.empty(0),
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: not used in transformer decoder, kept to unify the api
                with the bidirectional decoder
            reverse_weight: not used in transformer decoder, kept to unify the
                api with the bidirectional decoder
        Returns:
            (tuple): tuple containing:
                x: decoded token scores before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True
                torch.tensor(0.0), to unify the api with the bidirectional decoder
                olens: (batch, )
        NOTE(xcsong):
            We pass the `__call__` method of the modules instead of `forward` to the
            checkpointing API because `__call__` attaches all the hooks of the module.
            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
        """
        tgt = ys_in_pad
        maxlen = tgt.size(1)
        # tgt_mask: (B, 1, L)
        tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
        tgt_mask = tgt_mask.to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1),
                            device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m
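        # After this AND, tgt_mask combines the padding mask with the causal
        # mask: position i may attend only to non-padded positions j <= i.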
        x, _ = self.embed(tgt)
        if self.gradient_checkpointing and self.training:
            x = self.forward_layers_checkpointed(x, tgt_mask, memory,
                                                 memory_mask)
        else:
            x = self.forward_layers(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.use_output_layer:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, torch.tensor(0.0), olens

    def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
                       memory: torch.Tensor,
                       memory_mask: torch.Tensor) -> torch.Tensor:
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
                                                     memory_mask)
        return x

    @torch.jit.ignore(drop=True)
    def forward_layers_checkpointed(self, x: torch.Tensor,
                                    tgt_mask: torch.Tensor,
                                    memory: torch.Tensor,
                                    memory_mask: torch.Tensor) -> torch.Tensor:
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
                layer.__call__, x, tgt_mask, memory, memory_mask)
        return x

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        This is only used for decoding.

        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value for the last position and the updated
                cache per layer in `self.decoders`.
                `y.shape` is (batch, vocab_size)
        """
        x, _ = self.embed(tgt)
        new_cache = []
        for i, decoder in enumerate(self.decoders):
            if cache is None:
                c = None
            else:
                c = cache[i]
            x, tgt_mask, memory, memory_mask = decoder(x,
                                                       tgt_mask,
                                                       memory,
                                                       memory_mask,
                                                       cache=c)
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.use_output_layer:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache

    def tie_or_clone_weights(self, jit_mode: bool = True):
        """Tie or clone module weights (between word_emb and output_layer)
        depending on whether we are using TorchScript or not"""
        if not self.use_output_layer:
            return
        if jit_mode:
            logging.info("clone emb.weight to output.weight")
            self.output_layer.weight = torch.nn.Parameter(
                self.embed[0].weight.clone())
        else:
            logging.info("tie emb.weight with output.weight")
            self.output_layer.weight = self.embed[0].weight

        if getattr(self.output_layer, "bias", None) is not None:
            self.output_layer.bias.data = torch.nn.functional.pad(
                self.output_layer.bias.data,
                (
                    0,
                    self.output_layer.weight.shape[0] -
                    self.output_layer.bias.shape[0],
                ),
                "constant",
                0,
            )
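

# Usage sketch (illustrative only, not exercised anywhere in this module):
# given encoder outputs `memory` of shape (B, maxlen_in, encoder_output_size),
# a boolean `memory_mask` of shape (B, 1, maxlen_in), padded target ids
# `ys_in_pad` of shape (B, maxlen_out) and their lengths `ys_in_lens`, a
# teacher-forcing call looks like
#
#     decoder = TransformerDecoder(vocab_size, encoder_output_size)
#     logits, _, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
#
# where `logits` has shape (B, maxlen_out, vocab_size) when use_output_layer
# is True. A runnable smoke test is at the bottom of this file.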


class BiTransformerDecoder(torch.nn.Module):
    """Bidirectional Transformer decoder module: a left-to-right decoder plus
    an optional right-to-left decoder.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        r_num_blocks: the number of right to left decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        key_bias: whether to use bias in attention.linear_k, False for whisper models.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        r_num_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
        tie_word_embedding: bool = False,
    ):
        super().__init__()
        self.tie_word_embedding = tie_word_embedding
        self.left_decoder = TransformerDecoder(
            vocab_size,
            encoder_output_size,
            attention_heads,
            linear_units,
            num_blocks,
            dropout_rate,
            positional_dropout_rate,
            self_attention_dropout_rate,
            src_attention_dropout_rate,
            input_layer,
            use_output_layer,
            normalize_before,
            key_bias=key_bias,
            gradient_checkpointing=gradient_checkpointing,
            tie_word_embedding=tie_word_embedding)

        self.right_decoder = TransformerDecoder(
            vocab_size,
            encoder_output_size,
            attention_heads,
            linear_units,
            r_num_blocks,
            dropout_rate,
            positional_dropout_rate,
            self_attention_dropout_rate,
            src_attention_dropout_rate,
            input_layer,
            use_output_layer,
            normalize_before,
            key_bias=key_bias,
            gradient_checkpointing=gradient_checkpointing,
            tie_word_embedding=tie_word_embedding)

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor,
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
                used for the right to left decoder
            reverse_weight: used for the right to left decoder
        Returns:
            (tuple): tuple containing:
                x: decoded token scores before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True
                r_x: decoded token scores from the right to left decoder,
                    before softmax (batch, maxlen_out, vocab_size)
                    if use_output_layer is True
                olens: (batch, )
        """
        l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
                                          ys_in_lens)
        r_x = torch.tensor(0.0)
        if reverse_weight > 0.0:
            r_x, _, olens = self.right_decoder(memory, memory_mask,
                                               r_ys_in_pad, ys_in_lens)
        return l_x, r_x, olens

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        This is only used for decoding; it delegates to the left-to-right decoder.

        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value for the last position and the updated
                cache per layer in the left decoder's `self.decoders`.
                `y.shape` is (batch, vocab_size)
        """
        return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
                                                  tgt_mask, cache)

    def tie_or_clone_weights(self, jit_mode: bool = True):
        """Tie or clone module weights (between word_emb and output_layer)
        depending on whether we are using TorchScript or not"""
        self.left_decoder.tie_or_clone_weights(jit_mode)
        self.right_decoder.tie_or_clone_weights(jit_mode)
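

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only; the hyper-parameters and
    # tensor sizes below are arbitrary and not taken from any CosyVoice
    # recipe). Shapes follow the forward() docstrings above.
    vocab_size, enc_dim, batch, t_in, t_out = 50, 128, 2, 7, 5
    decoder = BiTransformerDecoder(vocab_size, enc_dim, attention_heads=4,
                                   linear_units=256, num_blocks=2,
                                   r_num_blocks=2)
    memory = torch.randn(batch, t_in, enc_dim)
    memory_mask = torch.ones(batch, 1, t_in, dtype=torch.bool)
    ys_in_pad = torch.randint(0, vocab_size, (batch, t_out))
    r_ys_in_pad = ys_in_pad.flip(dims=[1])  # crude stand-in for reversed targets
    ys_in_lens = torch.tensor([t_out, t_out - 1])
    l_x, r_x, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens,
                              r_ys_in_pad, reverse_weight=0.3)
    print(l_x.shape, r_x.shape)  # expected: (2, 5, 50) (2, 5, 50)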