# convolution.py
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""
from typing import Tuple

import torch
from torch import nn
import torch.nn.functional as F
  21. class ConvolutionModule(nn.Module):
  22. """ConvolutionModule in Conformer model."""
  23. def __init__(self,
  24. channels: int,
  25. kernel_size: int = 15,
  26. activation: nn.Module = nn.ReLU(),
  27. norm: str = "batch_norm",
  28. causal: bool = False,
  29. bias: bool = True):
  30. """Construct an ConvolutionModule object.
  31. Args:
  32. channels (int): The number of channels of conv layers.
  33. kernel_size (int): Kernel size of conv layers.
  34. causal (int): Whether use causal convolution or not
  35. """
  36. super().__init__()
  37. self.pointwise_conv1 = nn.Conv1d(
  38. channels,
  39. 2 * channels,
  40. kernel_size=1,
  41. stride=1,
  42. padding=0,
  43. bias=bias,
  44. )
  45. # self.lorder is used to distinguish if it's a causal convolution,
  46. # if self.lorder > 0: it's a causal convolution, the input will be
  47. # padded with self.lorder frames on the left in forward.
  48. # else: it's a symmetrical convolution
  49. if causal:
  50. padding = 0
  51. self.lorder = kernel_size - 1
  52. else:
  53. # kernel_size should be an odd number for none causal convolution
  54. assert (kernel_size - 1) % 2 == 0
  55. padding = (kernel_size - 1) // 2
  56. self.lorder = 0
  57. self.depthwise_conv = nn.Conv1d(
  58. channels,
  59. channels,
  60. kernel_size,
  61. stride=1,
  62. padding=padding,
  63. groups=channels,
  64. bias=bias,
  65. )
  66. assert norm in ['batch_norm', 'layer_norm']
  67. if norm == "batch_norm":
  68. self.use_layer_norm = False
  69. self.norm = nn.BatchNorm1d(channels)
  70. else:
  71. self.use_layer_norm = True
  72. self.norm = nn.LayerNorm(channels)
  73. self.pointwise_conv2 = nn.Conv1d(
  74. channels,
  75. channels,
  76. kernel_size=1,
  77. stride=1,
  78. padding=0,
  79. bias=bias,
  80. )
  81. self.activation = activation
  82. def forward(
  83. self,
  84. x: torch.Tensor,
  85. mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
  86. cache: torch.Tensor = torch.zeros((0, 0, 0)),
  87. ) -> Tuple[torch.Tensor, torch.Tensor]:
  88. """Compute convolution module.
  89. Args:
  90. x (torch.Tensor): Input tensor (#batch, time, channels).
  91. mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
  92. (0, 0, 0) means fake mask.
  93. cache (torch.Tensor): left context cache, it is only
  94. used in causal convolution (#batch, channels, cache_t),
  95. (0, 0, 0) meas fake cache.
  96. Returns:
  97. torch.Tensor: Output tensor (#batch, time, channels).
  98. """
  99. # exchange the temporal dimension and the feature dimension
  100. x = x.transpose(1, 2) # (#batch, channels, time)
  101. # mask batch padding
  102. if mask_pad.size(2) > 0: # time > 0
  103. x.masked_fill_(~mask_pad, 0.0)
  104. if self.lorder > 0:
  105. if cache.size(2) == 0: # cache_t == 0
  106. x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
  107. else:
  108. assert cache.size(0) == x.size(0) # equal batch
  109. assert cache.size(1) == x.size(1) # equal channel
  110. x = torch.cat((cache, x), dim=2)
  111. assert (x.size(2) > self.lorder)
  112. new_cache = x[:, :, -self.lorder:]
  113. else:
  114. # It's better we just return None if no cache is required,
  115. # However, for JIT export, here we just fake one tensor instead of
  116. # None.
  117. new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
  118. # GLU mechanism
  119. x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
  120. x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
  121. # 1D Depthwise Conv
  122. x = self.depthwise_conv(x)
  123. if self.use_layer_norm:
  124. x = x.transpose(1, 2)
  125. x = self.activation(self.norm(x))
  126. if self.use_layer_norm:
  127. x = x.transpose(1, 2)
  128. x = self.pointwise_conv2(x)
  129. # mask batch padding
  130. if mask_pad.size(2) > 0: # time > 0
  131. x.masked_fill_(~mask_pad, 0.0)
  132. return x.transpose(1, 2), new_cache
  133. # NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
  134. class CausalConv1d(torch.nn.Conv1d):
  135. def __init__(
  136. self,
  137. in_channels: int,
  138. out_channels: int,
  139. kernel_size: int,
  140. stride: int = 1,
  141. dilation: int = 1,
  142. groups: int = 1,
  143. bias: bool = True,
  144. padding_mode: str = 'zeros',
  145. causal_type: str = 'left',
  146. device=None,
  147. dtype=None
  148. ) -> None:
  149. super(CausalConv1d, self).__init__(in_channels, out_channels,
  150. kernel_size, stride=1,
  151. padding=0, dilation=dilation,
  152. groups=groups, bias=bias,
  153. padding_mode=padding_mode,
  154. device=device, dtype=dtype)
  155. assert stride == 1
  156. self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2
  157. assert causal_type in ['left', 'right']
  158. self.causal_type = causal_type
  159. def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor]:
  160. input_timestep = x.shape[2]
  161. if cache.size(2) == 0:
  162. cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
  163. assert cache.size(2) == self.causal_padding
  164. if self.causal_type == 'left':
  165. x = torch.concat([cache, x], dim=2)
  166. else:
  167. x = torch.concat([x, cache], dim=2)
  168. x = super(CausalConv1d, self).forward(x)
  169. assert x.shape[2] == input_timestep
  170. return x
  171. class CausalConv1dDownSample(torch.nn.Conv1d):
  172. def __init__(
  173. self,
  174. in_channels: int,
  175. out_channels: int,
  176. kernel_size: int,
  177. stride: int = 1,
  178. dilation: int = 1,
  179. groups: int = 1,
  180. bias: bool = True,
  181. padding_mode: str = 'zeros',
  182. device=None,
  183. dtype=None
  184. ) -> None:
  185. super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
  186. kernel_size, stride,
  187. padding=0, dilation=dilation,
  188. groups=groups, bias=bias,
  189. padding_mode=padding_mode,
  190. device=device, dtype=dtype)
  191. assert stride != 1 and dilation == 1
  192. assert kernel_size % stride == 0
  193. self.causal_padding = stride - 1
  194. def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
  195. if cache.size(2) == 0:
  196. x = F.pad(x, (self.causal_padding, 0), value=0.0)
  197. else:
  198. assert cache.size(2) == self.causal_padding
  199. x = torch.concat([cache, x], dim=2)
  200. x = super(CausalConv1dDownSample, self).forward(x)
  201. return x
  202. class CausalConv1dUpsample(torch.nn.Conv1d):
  203. def __init__(
  204. self,
  205. in_channels: int,
  206. out_channels: int,
  207. kernel_size: int,
  208. stride: int = 1,
  209. dilation: int = 1,
  210. groups: int = 1,
  211. bias: bool = True,
  212. padding_mode: str = 'zeros',
  213. device=None,
  214. dtype=None
  215. ) -> None:
  216. super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
  217. kernel_size, 1,
  218. padding=0, dilation=dilation,
  219. groups=groups, bias=bias,
  220. padding_mode=padding_mode,
  221. device=device, dtype=dtype)
  222. assert dilation == 1
  223. self.causal_padding = kernel_size - 1
  224. self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
  225. def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
  226. x = self.upsample(x)
  227. input_timestep = x.shape[2]
  228. if cache.size(2) == 0:
  229. x = F.pad(x, (self.causal_padding, 0), value=0.0)
  230. else:
  231. assert cache.size(2) == self.causal_padding
  232. x = torch.concat([cache, x], dim=2)
  233. x = super(CausalConv1dUpsample, self).forward(x)
  234. assert input_timestep == x.shape[2]
  235. return x