1
0

attention.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. # Copyright (c) 2019 Shigeki Karita
  2. # 2020 Mobvoi Inc (Binbin Zhang)
  3. # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
  4. # 2024 Alibaba Inc (Xiang Lyu)
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. """Multi-Head Attention layer definition."""
  18. import math
  19. from typing import Tuple
  20. import torch
  21. from torch import nn
  class MultiHeadedAttention(nn.Module):
      """Multi-Head Attention layer.

      Args:
          n_head (int): The number of heads.
          n_feat (int): The number of features. Must be divisible by
              ``n_head``; each head works on ``n_feat // n_head`` features.
          dropout_rate (float): Dropout rate applied to the attention
              weights.
          key_bias (bool): Whether ``linear_k`` (the key projection) has a
              bias term. Default: True.

      Raises:
          ValueError: If ``n_feat`` is not divisible by ``n_head``.
      """

      def __init__(self,
                   n_head: int,
                   n_feat: int,
                   dropout_rate: float,
                   key_bias: bool = True):
          """Construct a MultiHeadedAttention object."""
          super().__init__()
          # Explicit check instead of ``assert``: asserts are stripped under
          # ``python -O``, after which a non-divisible ``n_feat`` would
          # silently truncate ``d_k`` and only fail later inside ``view``.
          if n_feat % n_head != 0:
              raise ValueError(
                  f"n_feat ({n_feat}) must be divisible by "
                  f"n_head ({n_head})")
          # We assume d_v always equals d_k
          self.d_k = n_feat // n_head
          self.h = n_head
          self.linear_q = nn.Linear(n_feat, n_feat)
          self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
          self.linear_v = nn.Linear(n_feat, n_feat)
          self.linear_out = nn.Linear(n_feat, n_feat)
          self.dropout = nn.Dropout(p=dropout_rate)

      def forward_qkv(
          self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """Project query, key and value and split them into heads.

          Args:
              query (torch.Tensor): Query tensor (#batch, time1, size).
              key (torch.Tensor): Key tensor (#batch, time2, size).
              value (torch.Tensor): Value tensor (#batch, time2, size).

          Returns:
              torch.Tensor: Transformed query tensor, size
                  (#batch, n_head, time1, d_k).
              torch.Tensor: Transformed key tensor, size
                  (#batch, n_head, time2, d_k).
              torch.Tensor: Transformed value tensor, size
                  (#batch, n_head, time2, d_k).
          """
          n_batch = query.size(0)
          q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
          k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
          v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
          q = q.transpose(1, 2)  # (batch, head, time1, d_k)
          k = k.transpose(1, 2)  # (batch, head, time2, d_k)
          v = v.transpose(1, 2)  # (batch, head, time2, d_k)
          return q, k, v

      def forward_attention(
          self,
          value: torch.Tensor,
          scores: torch.Tensor,
          mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
      ) -> torch.Tensor:
          """Compute attention context vector.

          Args:
              value (torch.Tensor): Transformed value, size
                  (#batch, n_head, time2, d_k).
              scores (torch.Tensor): Attention score, size
                  (#batch, n_head, time1, time2).
              mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                  (#batch, time1, time2); a (0, 0, 0) tensor means "no
                  mask". Zero/False entries are excluded: their scores are
                  set to -inf before the softmax and their weights are
                  reset to 0 afterwards, so fully-masked rows yield zeros
                  instead of NaN.

          Returns:
              torch.Tensor: Attention output (#batch, time1, d_model),
              where d_model == n_head * d_k, after the output projection
              ``linear_out``.
          """
          n_batch = value.size(0)
          # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
          #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for
          #      the 1st chunk to ease the onnx export.]
          #   2. pytorch training
          if mask.size(2) > 0:  # time2 > 0
              mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
              # For last chunk, time2 might be larger than scores.size(-1)
              mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
              scores = scores.masked_fill(mask, -float('inf'))
              attn = torch.softmax(scores, dim=-1).masked_fill(
                  mask, 0.0)  # (batch, head, time1, time2)
          # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
          #   1. onnx(16/-1, -1/-1, 16/0)
          #   2. jit (16/-1, -1/-1, 16/0, 16/4)
          else:
              attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
          p_attn = self.dropout(attn)
          x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
          x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                   self.h * self.d_k)
               )  # (batch, time1, d_model)
          return self.linear_out(x)  # (batch, time1, d_model)

      def forward(
          self,
          query: torch.Tensor,
          key: torch.Tensor,
          value: torch.Tensor,
          mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
          pos_emb: torch.Tensor = torch.empty(0),
          cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Compute scaled dot product attention.

          Args:
              query (torch.Tensor): Query tensor (#batch, time1, size).
              key (torch.Tensor): Key tensor (#batch, time2, size).
              value (torch.Tensor): Value tensor (#batch, time2, size).
              mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                  (#batch, time1, time2).
                  1.When applying cross attention between decoder and
                    encoder, the batch padding mask for input is in
                    (#batch, 1, T) shape.
                  2.When applying self attention of encoder,
                    the mask is in (#batch, T, T) shape.
                  3.When applying self attention of decoder,
                    the mask is in (#batch, L, L) shape.
                  4.If the different position in decoder see different block
                    of the encoder, such as Mocha, the passed in mask could be
                    in (#batch, L, T) shape. But there is no such case in
                    current CosyVoice.
              pos_emb (torch.Tensor): Not used by this class.
              cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                  where `cache_t == chunk_size * num_decoding_left_chunks`
                  and `head * d_k == size`. A tensor with size(0) == 0 means
                  "no cache".

          Returns:
              torch.Tensor: Output tensor (#batch, time1, d_model).
              torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                  where `cache_t == chunk_size * num_decoding_left_chunks`
                  and `head * d_k == size`; keys and values concatenated
                  along the last axis.
          """
          q, k, v = self.forward_qkv(query, key, value)

          # NOTE(xcsong):
          #   when export onnx model, for 1st chunk, we feed
          #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
          #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
          #       In all modes, `if cache.size(0) > 0` will always be `True`
          #       and we will always do splitting and
          #       concatenation (this will simplify onnx export). Note that
          #       it's OK to concat & split zero-shaped tensors (see code below).
          #   when export jit model, for 1st chunk, we always feed
          #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
          # >>> a = torch.ones((1, 2, 0, 4))
          # >>> b = torch.ones((1, 2, 3, 4))
          # >>> c = torch.cat((a, b), dim=2)
          # >>> torch.equal(b, c)        # True
          # >>> d = torch.split(a, 2, dim=-1)
          # >>> torch.equal(d[0], d[1])  # True
          if cache.size(0) > 0:
              key_cache, value_cache = torch.split(cache,
                                                   cache.size(-1) // 2,
                                                   dim=-1)
              k = torch.cat([key_cache, k], dim=2)
              v = torch.cat([value_cache, v], dim=2)
          # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since
          #   it's non-trivial to calculate `next_cache_start` here.
          new_cache = torch.cat((k, v), dim=-1)

          scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
          return self.forward_attention(v, scores, mask), new_cache
  class RelPositionMultiHeadedAttention(MultiHeadedAttention):
      """Multi-Head Attention layer with relative position encoding.

      Paper: https://arxiv.org/abs/1901.02860

      Args:
          n_head (int): The number of heads.
          n_feat (int): The number of features.
          dropout_rate (float): Dropout rate.
          key_bias (bool): Whether the key projection has a bias term.
              Default: True.
      """

      def __init__(self,
                   n_head: int,
                   n_feat: int,
                   dropout_rate: float,
                   key_bias: bool = True):
          """Construct a RelPositionMultiHeadedAttention object."""
          super().__init__(n_head, n_feat, dropout_rate, key_bias)
          # linear transformation for positional encoding
          self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
          # these two learnable biases are used in matrix c and matrix d
          # as described in https://arxiv.org/abs/1901.02860 Section 3.3.
          # ``torch.empty`` replaces the legacy ``torch.Tensor(*sizes)``
          # constructor; the values are overwritten by xavier init below.
          self.pos_bias_u = nn.Parameter(torch.empty(self.h, self.d_k))
          self.pos_bias_v = nn.Parameter(torch.empty(self.h, self.d_k))
          torch.nn.init.xavier_uniform_(self.pos_bias_u)
          torch.nn.init.xavier_uniform_(self.pos_bias_v)

      def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
          """Realign relative-position scores with key positions.

          Prepends a zero column on the last axis, reinterprets the padded
          buffer as (batch, head, x.size(-1) + 1, time1), drops its first
          row, views the remainder back to ``x``'s shape and keeps the first
          ``x.size(-1) // 2 + 1`` columns.

          Args:
              x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                  time1 means the length of query vector.

          Returns:
              torch.Tensor: Output tensor
              (batch, head, time1, x.size(-1) // 2 + 1); for an input whose
              last axis is 2*time1-1 this is (batch, head, time1, time1).
          """
          zero_pad = torch.zeros((x.size(0), x.size(1), x.size(2), 1),
                                 device=x.device,
                                 dtype=x.dtype)
          x_padded = torch.cat([zero_pad, x], dim=-1)

          x_padded = x_padded.view(x.size(0),
                                   x.size(1),
                                   x.size(3) + 1, x.size(2))
          x = x_padded[:, :, 1:].view_as(x)[
              :, :, :, : x.size(-1) // 2 + 1
          ]  # only keep the positions from 0 to time2
          return x

      def forward(
          self,
          query: torch.Tensor,
          key: torch.Tensor,
          value: torch.Tensor,
          mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
          pos_emb: torch.Tensor = torch.empty(0),
          cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

          Args:
              query (torch.Tensor): Query tensor (#batch, time1, size).
              key (torch.Tensor): Key tensor (#batch, time2, size).
              value (torch.Tensor): Value tensor (#batch, time2, size).
              mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                  (#batch, time1, time2), (0, 0, 0) means fake mask.
              pos_emb (torch.Tensor): Positional embedding tensor
                  (#batch_pos, pos_len, size). #batch_pos must broadcast
                  against #batch in ``torch.matmul`` (i.e. 1 or #batch).
                  If pos_len == time2 (including cache) the scores are used
                  as is; otherwise ``rel_shift`` is applied and its output
                  width ``pos_len // 2 + 1`` must equal time2
                  (e.g. pos_len == 2*time2-1).
              cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                  where `cache_t == chunk_size * num_decoding_left_chunks`
                  and `head * d_k == size`. A tensor with size(0) == 0 means
                  "no cache".

          Returns:
              torch.Tensor: Output tensor (#batch, time1, d_model).
              torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                  where `cache_t == chunk_size * num_decoding_left_chunks`
                  and `head * d_k == size`.
          """
          q, k, v = self.forward_qkv(query, key, value)
          q = q.transpose(1, 2)  # (batch, time1, head, d_k)

          # NOTE(xcsong):
          #   when export onnx model, for 1st chunk, we feed
          #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
          #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
          #       In all modes, `if cache.size(0) > 0` will always be `True`
          #       and we will always do splitting and
          #       concatenation (this will simplify onnx export). Note that
          #       it's OK to concat & split zero-shaped tensors (see code below).
          #   when export jit model, for 1st chunk, we always feed
          #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
          # >>> a = torch.ones((1, 2, 0, 4))
          # >>> b = torch.ones((1, 2, 3, 4))
          # >>> c = torch.cat((a, b), dim=2)
          # >>> torch.equal(b, c)        # True
          # >>> d = torch.split(a, 2, dim=-1)
          # >>> torch.equal(d[0], d[1])  # True
          if cache.size(0) > 0:
              key_cache, value_cache = torch.split(cache,
                                                   cache.size(-1) // 2,
                                                   dim=-1)
              k = torch.cat([key_cache, k], dim=2)
              v = torch.cat([value_cache, v], dim=2)
          # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since
          #   it's non-trivial to calculate `next_cache_start` here.
          new_cache = torch.cat((k, v), dim=-1)

          n_batch_pos = pos_emb.size(0)
          p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
          p = p.transpose(1, 2)  # (batch_pos, head, pos_len, d_k)

          # (batch, head, time1, d_k)
          q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
          # (batch, head, time1, d_k)
          q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

          # compute attention score
          # first compute matrix a and matrix c
          # as described in https://arxiv.org/abs/1901.02860 Section 3.3
          # (batch, head, time1, time2)
          matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

          # compute matrix b and matrix d
          # (batch, head, time1, pos_len)
          matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
          # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
          # rel_shift realigns matrix_bd to (batch, head, time1, time2) when
          # pos_emb is not already one row per key position.
          if matrix_ac.shape != matrix_bd.shape:
              matrix_bd = self.rel_shift(matrix_bd)

          scores = (matrix_ac + matrix_bd) / math.sqrt(
              self.d_k)  # (batch, head, time1, time2)

          return self.forward_attention(v, scores, mask), new_cache
  284. # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
  285. if matrix_ac.shape != matrix_bd.shape:
  286. matrix_bd = self.rel_shift(matrix_bd)
  287. scores = (matrix_ac + matrix_bd) / math.sqrt(
  288. self.d_k) # (batch, head, time1, time2)
  289. return self.forward_attention(v, scores, mask), new_cache