embedding.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
  2. # 2024 Alibaba Inc (Xiang Lyu)
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # Modified from ESPnet(https://github.com/espnet/espnet)
  16. """Positonal Encoding Module."""
  17. import math
  18. from typing import Tuple, Union
  19. import torch
  20. import torch.nn.functional as F
  21. import numpy as np
  22. class PositionalEncoding(torch.nn.Module):
  23. """Positional encoding.
  24. :param int d_model: embedding dim
  25. :param float dropout_rate: dropout rate
  26. :param int max_len: maximum input length
  27. PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
  28. PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
  29. """
  30. def __init__(self,
  31. d_model: int,
  32. dropout_rate: float,
  33. max_len: int = 5000,
  34. reverse: bool = False):
  35. """Construct an PositionalEncoding object."""
  36. super().__init__()
  37. self.d_model = d_model
  38. self.xscale = math.sqrt(self.d_model)
  39. self.dropout = torch.nn.Dropout(p=dropout_rate)
  40. self.max_len = max_len
  41. self.pe = torch.zeros(self.max_len, self.d_model)
  42. position = torch.arange(0, self.max_len,
  43. dtype=torch.float32).unsqueeze(1)
  44. div_term = torch.exp(
  45. torch.arange(0, self.d_model, 2, dtype=torch.float32) *
  46. -(math.log(10000.0) / self.d_model))
  47. self.pe[:, 0::2] = torch.sin(position * div_term)
  48. self.pe[:, 1::2] = torch.cos(position * div_term)
  49. self.pe = self.pe.unsqueeze(0)
  50. def forward(self,
  51. x: torch.Tensor,
  52. offset: Union[int, torch.Tensor] = 0) \
  53. -> Tuple[torch.Tensor, torch.Tensor]:
  54. """Add positional encoding.
  55. Args:
  56. x (torch.Tensor): Input. Its shape is (batch, time, ...)
  57. offset (int, torch.tensor): position offset
  58. Returns:
  59. torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
  60. torch.Tensor: for compatibility to RelPositionalEncoding
  61. """
  62. self.pe = self.pe.to(x.device)
  63. pos_emb = self.position_encoding(offset, x.size(1), False)
  64. x = x * self.xscale + pos_emb
  65. return self.dropout(x), self.dropout(pos_emb)
  66. def position_encoding(self,
  67. offset: Union[int, torch.Tensor],
  68. size: int,
  69. apply_dropout: bool = True) -> torch.Tensor:
  70. """ For getting encoding in a streaming fashion
  71. Attention!!!!!
  72. we apply dropout only once at the whole utterance level in a none
  73. streaming way, but will call this function several times with
  74. increasing input size in a streaming scenario, so the dropout will
  75. be applied several times.
  76. Args:
  77. offset (int or torch.tensor): start offset
  78. size (int): required size of position encoding
  79. Returns:
  80. torch.Tensor: Corresponding encoding
  81. """
  82. # How to subscript a Union type:
  83. # https://github.com/pytorch/pytorch/issues/69434
  84. if isinstance(offset, int):
  85. assert offset + size <= self.max_len
  86. pos_emb = self.pe[:, offset:offset + size]
  87. elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
  88. assert offset + size <= self.max_len
  89. pos_emb = self.pe[:, offset:offset + size]
  90. else: # for batched streaming decoding on GPU
  91. assert torch.max(offset) + size <= self.max_len
  92. index = offset.unsqueeze(1) + \
  93. torch.arange(0, size).to(offset.device) # B X T
  94. flag = index > 0
  95. # remove negative offset
  96. index = index * flag
  97. pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
  98. if apply_dropout:
  99. pos_emb = self.dropout(pos_emb)
  100. return pos_emb
  101. class RelPositionalEncoding(PositionalEncoding):
  102. """Relative positional encoding module.
  103. See : Appendix B in https://arxiv.org/abs/1901.02860
  104. Args:
  105. d_model (int): Embedding dimension.
  106. dropout_rate (float): Dropout rate.
  107. max_len (int): Maximum input length.
  108. """
  109. def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
  110. """Initialize class."""
  111. super().__init__(d_model, dropout_rate, max_len, reverse=True)
  112. def forward(self,
  113. x: torch.Tensor,
  114. offset: Union[int, torch.Tensor] = 0) \
  115. -> Tuple[torch.Tensor, torch.Tensor]:
  116. """Compute positional encoding.
  117. Args:
  118. x (torch.Tensor): Input tensor (batch, time, `*`).
  119. Returns:
  120. torch.Tensor: Encoded tensor (batch, time, `*`).
  121. torch.Tensor: Positional embedding tensor (1, time, `*`).
  122. """
  123. self.pe = self.pe.to(x.device)
  124. x = x * self.xscale
  125. pos_emb = self.position_encoding(offset, x.size(1), False)
  126. return self.dropout(x), self.dropout(pos_emb)
  127. class WhisperPositionalEncoding(PositionalEncoding):
  128. """ Sinusoids position encoding used in openai-whisper.encoder
  129. """
  130. def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
  131. super().__init__(d_model, dropout_rate, max_len)
  132. self.xscale = 1.0
  133. log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
  134. inv_timescales = torch.exp(-log_timescale_increment *
  135. torch.arange(d_model // 2))
  136. scaled_time = torch.arange(max_len)[:, np.newaxis] * \
  137. inv_timescales[np.newaxis, :]
  138. pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
  139. delattr(self, "pe")
  140. self.register_buffer("pe", pe.unsqueeze(0))
  141. class LearnablePositionalEncoding(PositionalEncoding):
  142. """ Learnable position encoding used in openai-whisper.decoder
  143. """
  144. def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
  145. super().__init__(d_model, dropout_rate, max_len)
  146. # NOTE(xcsong): overwrite self.pe & self.xscale
  147. self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
  148. self.xscale = 1.0
  149. class NoPositionalEncoding(torch.nn.Module):
  150. """ No position encoding
  151. """
  152. def __init__(self, d_model: int, dropout_rate: float):
  153. super().__init__()
  154. self.d_model = d_model
  155. self.dropout = torch.nn.Dropout(p=dropout_rate)
  156. def forward(self,
  157. x: torch.Tensor,
  158. offset: Union[int, torch.Tensor] = 0) \
  159. -> Tuple[torch.Tensor, torch.Tensor]:
  160. """ Just return zero vector for interface compatibility
  161. """
  162. pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
  163. return self.dropout(x), pos_emb
  164. def position_encoding(self, offset: Union[int, torch.Tensor],
  165. size: int) -> torch.Tensor:
  166. return torch.zeros(1, size, self.d_model)
  167. class EspnetRelPositionalEncoding(torch.nn.Module):
  168. """Relative positional encoding module (new implementation).
  169. Details can be found in https://github.com/espnet/espnet/pull/2816.
  170. See : Appendix B in https://arxiv.org/abs/1901.02860
  171. Args:
  172. d_model (int): Embedding dimension.
  173. dropout_rate (float): Dropout rate.
  174. max_len (int): Maximum input length.
  175. """
  176. def __init__(self, d_model, dropout_rate, max_len=5000):
  177. """Construct an PositionalEncoding object."""
  178. super(EspnetRelPositionalEncoding, self).__init__()
  179. self.d_model = d_model
  180. self.xscale = math.sqrt(self.d_model)
  181. self.dropout = torch.nn.Dropout(p=dropout_rate)
  182. self.pe = None
  183. self.extend_pe(torch.tensor(0.0).expand(1, max_len))
  184. def extend_pe(self, x):
  185. """Reset the positional encodings."""
  186. if self.pe is not None:
  187. # self.pe contains both positive and negative parts
  188. # the length of self.pe is 2 * input_len - 1
  189. if self.pe.size(1) >= x.size(1) * 2 - 1:
  190. if self.pe.dtype != x.dtype or self.pe.device != x.device:
  191. self.pe = self.pe.to(dtype=x.dtype, device=x.device)
  192. return
  193. # Suppose `i` means to the position of query vecotr and `j` means the
  194. # position of key vector. We use position relative positions when keys
  195. # are to the left (i>j) and negative relative positions otherwise (i<j).
  196. pe_positive = torch.zeros(x.size(1), self.d_model)
  197. pe_negative = torch.zeros(x.size(1), self.d_model)
  198. position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
  199. div_term = torch.exp(
  200. torch.arange(0, self.d_model, 2, dtype=torch.float32)
  201. * -(math.log(10000.0) / self.d_model)
  202. )
  203. pe_positive[:, 0::2] = torch.sin(position * div_term)
  204. pe_positive[:, 1::2] = torch.cos(position * div_term)
  205. pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
  206. pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
  207. # Reserve the order of positive indices and concat both positive and
  208. # negative indices. This is used to support the shifting trick
  209. # as in https://arxiv.org/abs/1901.02860
  210. pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
  211. pe_negative = pe_negative[1:].unsqueeze(0)
  212. pe = torch.cat([pe_positive, pe_negative], dim=1)
  213. self.pe = pe.to(device=x.device, dtype=x.dtype)
  214. def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0):
  215. """Add positional encoding.
  216. Args:
  217. x (torch.Tensor): Input tensor (batch, time, `*`).
  218. Returns:
  219. torch.Tensor: Encoded tensor (batch, time, `*`).
  220. """
  221. self.extend_pe(x)
  222. x = x * self.xscale
  223. pos_emb = self.position_encoding(size=x.size(1), offset=offset)
  224. return self.dropout(x), self.dropout(pos_emb)
  225. def position_encoding(self,
  226. offset: Union[int, torch.Tensor],
  227. size: int) -> torch.Tensor:
  228. """ For getting encoding in a streaming fashion
  229. Attention!!!!!
  230. we apply dropout only once at the whole utterance level in a none
  231. streaming way, but will call this function several times with
  232. increasing input size in a streaming scenario, so the dropout will
  233. be applied several times.
  234. Args:
  235. offset (int or torch.tensor): start offset
  236. size (int): required size of position encoding
  237. Returns:
  238. torch.Tensor: Corresponding encoding
  239. """
  240. pos_emb = self.pe[
  241. :,
  242. self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
  243. ]
  244. return pos_emb