subsampling.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. # 2024 Alibaba Inc (Xiang Lyu)
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # Modified from ESPnet(https://github.com/espnet/espnet)
  16. """Subsampling layer definition."""
  17. from typing import Tuple, Union
  18. import torch
  19. class BaseSubsampling(torch.nn.Module):
  20. def __init__(self):
  21. super().__init__()
  22. self.right_context = 0
  23. self.subsampling_rate = 1
  24. def position_encoding(self, offset: Union[int, torch.Tensor],
  25. size: int) -> torch.Tensor:
  26. return self.pos_enc.position_encoding(offset, size)
  27. class EmbedinigNoSubsampling(BaseSubsampling):
  28. """Embedding input without subsampling
  29. """
  30. def __init__(self, idim: int, odim: int, dropout_rate: float,
  31. pos_enc_class: torch.nn.Module):
  32. super().__init__()
  33. self.embed = torch.nn.Embedding(idim, odim)
  34. self.pos_enc = pos_enc_class
  35. def forward(
  36. self,
  37. x: torch.Tensor,
  38. x_mask: torch.Tensor,
  39. offset: Union[int, torch.Tensor] = 0
  40. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  41. """Input x.
  42. Args:
  43. x (torch.Tensor): Input tensor (#batch, time, idim).
  44. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  45. Returns:
  46. torch.Tensor: linear input tensor (#batch, time', odim),
  47. where time' = time .
  48. torch.Tensor: linear input mask (#batch, 1, time'),
  49. where time' = time .
  50. """
  51. x = self.embed(x)
  52. x, pos_emb = self.pos_enc(x, offset)
  53. return x, pos_emb, x_mask
  54. class LinearNoSubsampling(BaseSubsampling):
  55. """Linear transform the input without subsampling
  56. Args:
  57. idim (int): Input dimension.
  58. odim (int): Output dimension.
  59. dropout_rate (float): Dropout rate.
  60. """
  61. def __init__(self, idim: int, odim: int, dropout_rate: float,
  62. pos_enc_class: torch.nn.Module):
  63. """Construct an linear object."""
  64. super().__init__()
  65. self.out = torch.nn.Sequential(
  66. torch.nn.Linear(idim, odim),
  67. torch.nn.LayerNorm(odim, eps=1e-5),
  68. torch.nn.Dropout(dropout_rate),
  69. )
  70. self.pos_enc = pos_enc_class
  71. self.right_context = 0
  72. self.subsampling_rate = 1
  73. def forward(
  74. self,
  75. x: torch.Tensor,
  76. x_mask: torch.Tensor,
  77. offset: Union[int, torch.Tensor] = 0
  78. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  79. """Input x.
  80. Args:
  81. x (torch.Tensor): Input tensor (#batch, time, idim).
  82. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  83. Returns:
  84. torch.Tensor: linear input tensor (#batch, time', odim),
  85. where time' = time .
  86. torch.Tensor: linear input mask (#batch, 1, time'),
  87. where time' = time .
  88. """
  89. x = self.out(x)
  90. x, pos_emb = self.pos_enc(x, offset)
  91. return x, pos_emb, x_mask
  92. class Conv1dSubsampling2(BaseSubsampling):
  93. """Convolutional 1D subsampling (to 1/2 length).
  94. It is designed for Whisper, ref:
  95. https://github.com/openai/whisper/blob/main/whisper/model.py
  96. Args:
  97. idim (int): Input dimension.
  98. odim (int): Output dimension.
  99. dropout_rate (float): Dropout rate.
  100. """
  101. def __init__(self, idim: int, odim: int, dropout_rate: float,
  102. pos_enc_class: torch.nn.Module):
  103. """Construct an Conv1dSubsampling2 object."""
  104. super().__init__()
  105. self.conv = torch.nn.Sequential(
  106. torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
  107. torch.nn.GELU(),
  108. torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
  109. torch.nn.GELU(),
  110. )
  111. self.pos_enc = pos_enc_class
  112. # The right context for every conv layer is computed by:
  113. # (kernel_size - 1) * frame_rate_of_this_layer
  114. self.subsampling_rate = 2
  115. # 4 = (3 - 1) * 1 + (3 - 1) * 1
  116. self.right_context = 4
  117. def forward(
  118. self,
  119. x: torch.Tensor,
  120. x_mask: torch.Tensor,
  121. offset: Union[int, torch.Tensor] = 0
  122. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  123. """Subsample x.
  124. Args:
  125. x (torch.Tensor): Input tensor (#batch, time, idim).
  126. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  127. Returns:
  128. torch.Tensor: Subsampled tensor (#batch, time', odim),
  129. where time' = time // 2.
  130. torch.Tensor: Subsampled mask (#batch, 1, time'),
  131. where time' = time // 2.
  132. torch.Tensor: positional encoding
  133. """
  134. time = x.size(1)
  135. x = x.transpose(1, 2) # (b, f, t)
  136. x = self.conv(x)
  137. x = x.transpose(1, 2) # (b, t, f)
  138. x, pos_emb = self.pos_enc(x, offset)
  139. return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
  140. class Conv2dSubsampling4(BaseSubsampling):
  141. """Convolutional 2D subsampling (to 1/4 length).
  142. Args:
  143. idim (int): Input dimension.
  144. odim (int): Output dimension.
  145. dropout_rate (float): Dropout rate.
  146. """
  147. def __init__(self, idim: int, odim: int, dropout_rate: float,
  148. pos_enc_class: torch.nn.Module):
  149. """Construct an Conv2dSubsampling4 object."""
  150. super().__init__()
  151. self.conv = torch.nn.Sequential(
  152. torch.nn.Conv2d(1, odim, 3, 2),
  153. torch.nn.ReLU(),
  154. torch.nn.Conv2d(odim, odim, 3, 2),
  155. torch.nn.ReLU(),
  156. )
  157. self.out = torch.nn.Sequential(
  158. torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
  159. self.pos_enc = pos_enc_class
  160. # The right context for every conv layer is computed by:
  161. # (kernel_size - 1) * frame_rate_of_this_layer
  162. self.subsampling_rate = 4
  163. # 6 = (3 - 1) * 1 + (3 - 1) * 2
  164. self.right_context = 6
  165. def forward(
  166. self,
  167. x: torch.Tensor,
  168. x_mask: torch.Tensor,
  169. offset: Union[int, torch.Tensor] = 0
  170. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  171. """Subsample x.
  172. Args:
  173. x (torch.Tensor): Input tensor (#batch, time, idim).
  174. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  175. Returns:
  176. torch.Tensor: Subsampled tensor (#batch, time', odim),
  177. where time' = time // 4.
  178. torch.Tensor: Subsampled mask (#batch, 1, time'),
  179. where time' = time // 4.
  180. torch.Tensor: positional encoding
  181. """
  182. x = x.unsqueeze(1) # (b, c=1, t, f)
  183. x = self.conv(x)
  184. b, c, t, f = x.size()
  185. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  186. x, pos_emb = self.pos_enc(x, offset)
  187. return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
  188. class Conv2dSubsampling6(BaseSubsampling):
  189. """Convolutional 2D subsampling (to 1/6 length).
  190. Args:
  191. idim (int): Input dimension.
  192. odim (int): Output dimension.
  193. dropout_rate (float): Dropout rate.
  194. pos_enc (torch.nn.Module): Custom position encoding layer.
  195. """
  196. def __init__(self, idim: int, odim: int, dropout_rate: float,
  197. pos_enc_class: torch.nn.Module):
  198. """Construct an Conv2dSubsampling6 object."""
  199. super().__init__()
  200. self.conv = torch.nn.Sequential(
  201. torch.nn.Conv2d(1, odim, 3, 2),
  202. torch.nn.ReLU(),
  203. torch.nn.Conv2d(odim, odim, 5, 3),
  204. torch.nn.ReLU(),
  205. )
  206. self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
  207. odim)
  208. self.pos_enc = pos_enc_class
  209. # 10 = (3 - 1) * 1 + (5 - 1) * 2
  210. self.subsampling_rate = 6
  211. self.right_context = 10
  212. def forward(
  213. self,
  214. x: torch.Tensor,
  215. x_mask: torch.Tensor,
  216. offset: Union[int, torch.Tensor] = 0
  217. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  218. """Subsample x.
  219. Args:
  220. x (torch.Tensor): Input tensor (#batch, time, idim).
  221. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  222. Returns:
  223. torch.Tensor: Subsampled tensor (#batch, time', odim),
  224. where time' = time // 6.
  225. torch.Tensor: Subsampled mask (#batch, 1, time'),
  226. where time' = time // 6.
  227. torch.Tensor: positional encoding
  228. """
  229. x = x.unsqueeze(1) # (b, c, t, f)
  230. x = self.conv(x)
  231. b, c, t, f = x.size()
  232. x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
  233. x, pos_emb = self.pos_enc(x, offset)
  234. return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
  235. class Conv2dSubsampling8(BaseSubsampling):
  236. """Convolutional 2D subsampling (to 1/8 length).
  237. Args:
  238. idim (int): Input dimension.
  239. odim (int): Output dimension.
  240. dropout_rate (float): Dropout rate.
  241. """
  242. def __init__(self, idim: int, odim: int, dropout_rate: float,
  243. pos_enc_class: torch.nn.Module):
  244. """Construct an Conv2dSubsampling8 object."""
  245. super().__init__()
  246. self.conv = torch.nn.Sequential(
  247. torch.nn.Conv2d(1, odim, 3, 2),
  248. torch.nn.ReLU(),
  249. torch.nn.Conv2d(odim, odim, 3, 2),
  250. torch.nn.ReLU(),
  251. torch.nn.Conv2d(odim, odim, 3, 2),
  252. torch.nn.ReLU(),
  253. )
  254. self.linear = torch.nn.Linear(
  255. odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
  256. self.pos_enc = pos_enc_class
  257. self.subsampling_rate = 8
  258. # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
  259. self.right_context = 14
  260. def forward(
  261. self,
  262. x: torch.Tensor,
  263. x_mask: torch.Tensor,
  264. offset: Union[int, torch.Tensor] = 0
  265. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  266. """Subsample x.
  267. Args:
  268. x (torch.Tensor): Input tensor (#batch, time, idim).
  269. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  270. Returns:
  271. torch.Tensor: Subsampled tensor (#batch, time', odim),
  272. where time' = time // 8.
  273. torch.Tensor: Subsampled mask (#batch, 1, time'),
  274. where time' = time // 8.
  275. torch.Tensor: positional encoding
  276. """
  277. x = x.unsqueeze(1) # (b, c, t, f)
  278. x = self.conv(x)
  279. b, c, t, f = x.size()
  280. x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
  281. x, pos_emb = self.pos_enc(x, offset)
  282. return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
  283. class LegacyLinearNoSubsampling(BaseSubsampling):
  284. """Linear transform the input without subsampling
  285. Args:
  286. idim (int): Input dimension.
  287. odim (int): Output dimension.
  288. dropout_rate (float): Dropout rate.
  289. """
  290. def __init__(self, idim: int, odim: int, dropout_rate: float,
  291. pos_enc_class: torch.nn.Module):
  292. """Construct an linear object."""
  293. super().__init__()
  294. self.out = torch.nn.Sequential(
  295. torch.nn.Linear(idim, odim),
  296. torch.nn.LayerNorm(odim, eps=1e-5),
  297. torch.nn.Dropout(dropout_rate),
  298. torch.nn.ReLU(),
  299. )
  300. self.pos_enc = pos_enc_class
  301. self.right_context = 0
  302. self.subsampling_rate = 1
  303. def forward(
  304. self,
  305. x: torch.Tensor,
  306. x_mask: torch.Tensor,
  307. offset: Union[int, torch.Tensor] = 0
  308. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  309. """Input x.
  310. Args:
  311. x (torch.Tensor): Input tensor (#batch, time, idim).
  312. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  313. Returns:
  314. torch.Tensor: linear input tensor (#batch, time', odim),
  315. where time' = time .
  316. torch.Tensor: linear input mask (#batch, 1, time'),
  317. where time' = time .
  318. """
  319. x = self.out(x)
  320. x, pos_emb = self.pos_enc(x, offset)
  321. return x, pos_emb, x_mask