decoder.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock


class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor):
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


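# Causal variant of Block1D: self.block is rebuilt around a CausalConv1d, with
# LayerNorm applied over the channel dimension via the Transpose wrappers, so
# no future frames leak into the output.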
class CausalBlock1D(Block1D):
    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        output = self.block(x * mask)
        return output * mask


class CausalResnetBlock1D(ResnetBlock1D):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)


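# CausalConv1d pads (kernel_size - 1) frames on the left only, so every output
# frame depends exclusively on the current and past input frames.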
class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        self.causal_padding = (kernel_size - 1, 0)

    def forward(self, x: torch.Tensor):
        x = F.pad(x, self.causal_padding)
        x = super(CausalConv1d, self).forward(x)
        return x


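# ConditionalDecoder is a 1D U-Net: down blocks (resnet + transformer blocks +
# downsampling), mid blocks at the bottleneck, and up blocks that consume skip
# connections from the encoder path. Speaker and auxiliary conditions are
# concatenated with the input along the channel axis.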
class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        causal=False,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        static_chunk_size=0,  # used by the chunked attention mask in forward(); 0 keeps full-context attention (assumed default)
    ):
        """
        This decoder requires an input with the same shape as the target, so if your text content
        is shorter or longer than the output, resample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.causal = causal
        self.static_chunk_size = static_chunk_size
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
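        # Encoder path: each stage pairs a (Causal)ResnetBlock1D with n_blocks
        # transformer blocks and ends with a downsampling conv (length-preserving
        # for the last stage).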
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else
                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
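        # Bottleneck: num_mid_blocks stages of resnet + transformer blocks at the
        # lowest temporal resolution, all at channels[-1] width.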
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            output_channel = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
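        # Decoder path: channel order is reversed and each stage consumes a skip
        # connection, so each resnet input is doubled by the concatenated skip.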
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            ) if self.causal else ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): conditioning features, concatenated with x along the channel axis
            t (torch.Tensor): shape (batch_size,)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): extra conditioning, concatenated with x along the channel axis when provided. Defaults to None.

        Returns:
            torch.Tensor: shape (batch_size, out_channels, time)
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
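

# The block below is a minimal usage sketch rather than part of the original
# module: it assumes the matcha/cosyvoice dependencies are importable and that
# the static_chunk_size argument added above is accepted; all sizes and argument
# values here are illustrative only.
if __name__ == "__main__":
    # x (80 feature channels) is packed with mu (80 channels), so in_channels=160.
    decoder = ConditionalDecoder(
        in_channels=160,
        out_channels=80,
        causal=True,
        channels=(256, 256),
        act_fn="gelu",          # illustrative activation for the transformer blocks
        static_chunk_size=50,   # illustrative chunk size for the block-causal attention mask
    )
    batch, length = 2, 64
    x = torch.randn(batch, 80, length)    # noisy feature trajectory
    mu = torch.randn(batch, 80, length)   # conditioning features, same length as x
    mask = torch.ones(batch, 1, length)   # all frames valid
    t = torch.rand(batch)                 # flow/diffusion timestep per sample
    y = decoder(x, mask, mu, t)
    print(y.shape)                        # expected: torch.Size([2, 80, 64])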