# convolution.py
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet (https://github.com/espnet/espnet)
"""ConvolutionModule definition."""

from typing import Tuple

import torch
from torch import nn


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True):
  29. """Construct an ConvolutionModule object.
  30. Args:
  31. channels (int): The number of channels of conv layers.
  32. kernel_size (int): Kernel size of conv layers.
  33. causal (int): Whether use causal convolution or not
  34. """
  35. super().__init__()
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution:
        # if self.lorder > 0, it's a causal convolution and the input will be
        # padded with self.lorder frames on the left in forward();
        # otherwise it's a symmetrical convolution.
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for non-causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
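        # For example (arithmetic implied by the code above): with the default
        # kernel_size=15, a causal module left-pads lorder=14 frames so each
        # output frame depends only on current and past input frames, while a
        # non-causal module pads 7 frames on both sides so the output length
        # matches the input length.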
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )
        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)
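        # Note: nn.BatchNorm1d normalizes per channel over (batch, time) on a
        # (batch, channels, time) layout, while nn.LayerNorm normalizes the
        # last dimension; that is why forward() transposes around the norm
        # when use_layer_norm is True.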
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
  87. """Compute convolution module.
  88. Args:
  89. x (torch.Tensor): Input tensor (#batch, time, channels).
  90. mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
  91. (0, 0, 0) means fake mask.
  92. cache (torch.Tensor): left context cache, it is only
  93. used in causal convolution (#batch, channels, cache_t),
  94. (0, 0, 0) meas fake cache.
  95. Returns:
  96. torch.Tensor: Output tensor (#batch, time, channels).
  97. """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
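            # Keep only the last self.lorder frames as left context for the
            # next streaming chunk.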
            new_cache = x[:, :, -self.lorder:]
        else:
            # Ideally we would just return None when no cache is required;
            # however, for JIT export we return a fake empty tensor instead
            # of None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)
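        # glu splits the 2*channel dimension into two halves a and b along
        # dim=1 and computes a * sigmoid(b), i.e. a learned gate per channel.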
        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)
        return x.transpose(1, 2), new_cache
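

# A minimal usage sketch (not part of the original file); it assumes only the
# behavior shown above: a (batch, time, channels) input, a boolean
# (batch, 1, time) pad mask, and the (output, new_cache) return pair.
if __name__ == "__main__":
    batch, time, channels = 2, 32, 64
    module = ConvolutionModule(channels, kernel_size=15, causal=True)
    x = torch.randn(batch, time, channels)
    mask_pad = torch.ones(batch, 1, time, dtype=torch.bool)

    # First chunk: no cache yet, so the default (0, 0, 0) fake cache is used.
    y, cache = module(x, mask_pad)
    print(y.shape)      # torch.Size([2, 32, 64])
    print(cache.shape)  # torch.Size([2, 64, 14]), lorder = kernel_size - 1

    # Next chunk: feed the returned cache back in for streaming inference.
    y2, cache2 = module(torch.randn(batch, time, channels), mask_pad, cache)
    print(y2.shape)     # torch.Size([2, 32, 64])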