# decoder.py
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from einops import pack, rearrange, repeat

from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock


  19. class ConditionalDecoder(nn.Module):
  20. def __init__(
  21. self,
  22. in_channels,
  23. out_channels,
  24. channels=(256, 256),
  25. dropout=0.05,
  26. attention_head_dim=64,
  27. n_blocks=1,
  28. num_mid_blocks=2,
  29. num_heads=4,
  30. act_fn="snake",
  31. ):
  32. """
  33. This decoder requires an input with the same shape of the target. So, if your text content
  34. is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
  35. """
  36. super().__init__()
  37. channels = tuple(channels)
  38. self.in_channels = in_channels
  39. self.out_channels = out_channels
  40. self.time_embeddings = SinusoidalPosEmb(in_channels)
  41. time_embed_dim = channels[0] * 4
  42. self.time_mlp = TimestepEmbedding(
  43. in_channels=in_channels,
  44. time_embed_dim=time_embed_dim,
  45. act_fn="silu",
  46. )
  47. self.down_blocks = nn.ModuleList([])
  48. self.mid_blocks = nn.ModuleList([])
  49. self.up_blocks = nn.ModuleList([])
  50. output_channel = in_channels
  51. for i in range(len(channels)): # pylint: disable=consider-using-enumerate
  52. input_channel = output_channel
  53. output_channel = channels[i]
  54. is_last = i == len(channels) - 1
  55. resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
  56. transformer_blocks = nn.ModuleList(
  57. [
  58. BasicTransformerBlock(
  59. dim=output_channel,
  60. num_attention_heads=num_heads,
  61. attention_head_dim=attention_head_dim,
  62. dropout=dropout,
  63. activation_fn=act_fn,
  64. )
  65. for _ in range(n_blocks)
  66. ]
  67. )
  68. downsample = (
  69. Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
  70. )
  71. self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
  72. for i in range(num_mid_blocks):
  73. input_channel = channels[-1]
  74. out_channels = channels[-1]
  75. resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
  76. transformer_blocks = nn.ModuleList(
  77. [
  78. BasicTransformerBlock(
  79. dim=output_channel,
  80. num_attention_heads=num_heads,
  81. attention_head_dim=attention_head_dim,
  82. dropout=dropout,
  83. activation_fn=act_fn,
  84. )
  85. for _ in range(n_blocks)
  86. ]
  87. )
  88. self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
  89. channels = channels[::-1] + (channels[0],)
  90. for i in range(len(channels) - 1):
  91. input_channel = channels[i] * 2
  92. output_channel = channels[i + 1]
  93. is_last = i == len(channels) - 2
  94. resnet = ResnetBlock1D(
  95. dim=input_channel,
  96. dim_out=output_channel,
  97. time_emb_dim=time_embed_dim,
  98. )
  99. transformer_blocks = nn.ModuleList(
  100. [
  101. BasicTransformerBlock(
  102. dim=output_channel,
  103. num_attention_heads=num_heads,
  104. attention_head_dim=attention_head_dim,
  105. dropout=dropout,
  106. activation_fn=act_fn,
  107. )
  108. for _ in range(n_blocks)
  109. ]
  110. )
  111. upsample = (
  112. Upsample1D(output_channel, use_conv_transpose=True)
  113. if not is_last
  114. else nn.Conv1d(output_channel, output_channel, 3, padding=1)
  115. )
  116. self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
  117. self.final_block = Block1D(channels[-1], channels[-1])
  118. self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
  119. self.initialize_weights()
  120. def initialize_weights(self):
  121. for m in self.modules():
  122. if isinstance(m, nn.Conv1d):
  123. nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
  124. if m.bias is not None:
  125. nn.init.constant_(m.bias, 0)
  126. elif isinstance(m, nn.GroupNorm):
  127. nn.init.constant_(m.weight, 1)
  128. nn.init.constant_(m.bias, 0)
  129. elif isinstance(m, nn.Linear):
  130. nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
  131. if m.bias is not None:
  132. nn.init.constant_(m.bias, 0)
  133. def forward(self, x, mask, mu, t, spks=None, cond=None):
  134. """Forward pass of the UNet1DConditional model.
  135. Args:
  136. x (torch.Tensor): shape (batch_size, in_channels, time)
  137. mask (_type_): shape (batch_size, 1, time)
  138. t (_type_): shape (batch_size)
  139. spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
  140. cond (_type_, optional): placeholder for future use. Defaults to None.
  141. Raises:
  142. ValueError: _description_
  143. ValueError: _description_
  144. Returns:
  145. _type_: _description_
  146. """
  147. t = self.time_embeddings(t)
  148. t = self.time_mlp(t)
  149. x = pack([x, mu], "b * t")[0]
  150. if spks is not None:
  151. spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
  152. x = pack([x, spks], "b * t")[0]
  153. if cond is not None:
  154. x = pack([x, cond], "b * t")[0]
  155. hiddens = []
  156. masks = [mask]
  157. for resnet, transformer_blocks, downsample in self.down_blocks:
  158. mask_down = masks[-1]
  159. x = resnet(x, mask_down, t)
  160. x = rearrange(x, "b c t -> b t c").contiguous()
  161. attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
  162. for transformer_block in transformer_blocks:
  163. x = transformer_block(
  164. hidden_states=x,
  165. attention_mask=attn_mask,
  166. timestep=t,
  167. )
  168. x = rearrange(x, "b t c -> b c t").contiguous()
  169. hiddens.append(x) # Save hidden states for skip connections
  170. x = downsample(x * mask_down)
  171. masks.append(mask_down[:, :, ::2])
  172. masks = masks[:-1]
  173. mask_mid = masks[-1]
  174. for resnet, transformer_blocks in self.mid_blocks:
  175. x = resnet(x, mask_mid, t)
  176. x = rearrange(x, "b c t -> b t c").contiguous()
  177. attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
  178. for transformer_block in transformer_blocks:
  179. x = transformer_block(
  180. hidden_states=x,
  181. attention_mask=attn_mask,
  182. timestep=t,
  183. )
  184. x = rearrange(x, "b t c -> b c t").contiguous()
  185. for resnet, transformer_blocks, upsample in self.up_blocks:
  186. mask_up = masks.pop()
  187. skip = hiddens.pop()
  188. x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
  189. x = resnet(x, mask_up, t)
  190. x = rearrange(x, "b c t -> b t c").contiguous()
  191. attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
  192. for transformer_block in transformer_blocks:
  193. x = transformer_block(
  194. hidden_states=x,
  195. attention_mask=attn_mask,
  196. timestep=t,
  197. )
  198. x = rearrange(x, "b t c -> b c t").contiguous()
  199. x = upsample(x * mask_up)
  200. x = self.final_block(x, mask_up)
  201. output = self.final_proj(x * mask_up)
  202. return output * mask