1
0

decoder.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from einops import pack, rearrange, repeat
  18. from cosyvoice.utils.common import mask_to_bias
  19. from cosyvoice.utils.mask import add_optional_chunk_mask
  20. from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
  21. from matcha.models.components.transformer import BasicTransformerBlock
  22. class Transpose(torch.nn.Module):
  23. def __init__(self, dim0: int, dim1: int):
  24. super().__init__()
  25. self.dim0 = dim0
  26. self.dim1 = dim1
  27. def forward(self, x: torch.Tensor):
  28. x = torch.transpose(x, self.dim0, self.dim1)
  29. return x
  30. class CausalBlock1D(Block1D):
  31. def __init__(self, dim: int, dim_out: int):
  32. super(CausalBlock1D, self).__init__(dim, dim_out)
  33. self.block = torch.nn.Sequential(
  34. CausalConv1d(dim, dim_out, 3),
  35. Transpose(1, 2),
  36. nn.LayerNorm(dim_out),
  37. Transpose(1, 2),
  38. nn.Mish(),
  39. )
  40. def forward(self, x: torch.Tensor, mask: torch.Tensor):
  41. output = self.block(x * mask)
  42. return output * mask
  43. class CausalResnetBlock1D(ResnetBlock1D):
  44. def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int=8):
  45. super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
  46. self.block1 = CausalBlock1D(dim, dim_out)
  47. self.block2 = CausalBlock1D(dim_out, dim_out)
  48. class CausalConv1d(torch.nn.Conv1d):
  49. def __init__(
  50. self,
  51. in_channels: int,
  52. out_channels: int,
  53. kernel_size: int,
  54. stride: int = 1,
  55. dilation: int = 1,
  56. groups: int = 1,
  57. bias: bool = True,
  58. padding_mode: str = 'zeros',
  59. device=None,
  60. dtype=None
  61. ) -> None:
  62. super(CausalConv1d, self).__init__(in_channels, out_channels,
  63. kernel_size, stride,
  64. padding=0, dilation=dilation,
  65. groups=groups, bias=bias,
  66. padding_mode=padding_mode,
  67. device=device, dtype=dtype
  68. )
  69. assert stride == 1
  70. self.causal_padding = (kernel_size - 1, 0)
  71. def forward(self, x: torch.Tensor):
  72. x = F.pad(x, self.causal_padding)
  73. x = super(CausalConv1d, self).forward(x)
  74. return x
  75. class ConditionalDecoder(nn.Module):
  76. def __init__(
  77. self,
  78. in_channels,
  79. out_channels,
  80. causal=False,
  81. channels=(256, 256),
  82. dropout=0.05,
  83. attention_head_dim=64,
  84. n_blocks=1,
  85. num_mid_blocks=2,
  86. num_heads=4,
  87. act_fn="snake",
  88. ):
  89. """
  90. This decoder requires an input with the same shape of the target. So, if your text content
  91. is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
  92. """
  93. super().__init__()
  94. channels = tuple(channels)
  95. self.in_channels = in_channels
  96. self.out_channels = out_channels
  97. self.causal = causal
  98. self.time_embeddings = SinusoidalPosEmb(in_channels)
  99. time_embed_dim = channels[0] * 4
  100. self.time_mlp = TimestepEmbedding(
  101. in_channels=in_channels,
  102. time_embed_dim=time_embed_dim,
  103. act_fn="silu",
  104. )
  105. self.down_blocks = nn.ModuleList([])
  106. self.mid_blocks = nn.ModuleList([])
  107. self.up_blocks = nn.ModuleList([])
  108. output_channel = in_channels
  109. for i in range(len(channels)): # pylint: disable=consider-using-enumerate
  110. input_channel = output_channel
  111. output_channel = channels[i]
  112. is_last = i == len(channels) - 1
  113. resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
  114. transformer_blocks = nn.ModuleList(
  115. [
  116. BasicTransformerBlock(
  117. dim=output_channel,
  118. num_attention_heads=num_heads,
  119. attention_head_dim=attention_head_dim,
  120. dropout=dropout,
  121. activation_fn=act_fn,
  122. )
  123. for _ in range(n_blocks)
  124. ]
  125. )
  126. downsample = (
  127. Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
  128. )
  129. self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
  130. for _ in range(num_mid_blocks):
  131. input_channel = channels[-1]
  132. out_channels = channels[-1]
  133. resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
  134. transformer_blocks = nn.ModuleList(
  135. [
  136. BasicTransformerBlock(
  137. dim=output_channel,
  138. num_attention_heads=num_heads,
  139. attention_head_dim=attention_head_dim,
  140. dropout=dropout,
  141. activation_fn=act_fn,
  142. )
  143. for _ in range(n_blocks)
  144. ]
  145. )
  146. self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
  147. channels = channels[::-1] + (channels[0],)
  148. for i in range(len(channels) - 1):
  149. input_channel = channels[i] * 2
  150. output_channel = channels[i + 1]
  151. is_last = i == len(channels) - 2
  152. resnet = CausalResnetBlock1D(
  153. dim=input_channel,
  154. dim_out=output_channel,
  155. time_emb_dim=time_embed_dim,
  156. ) if self.causal else ResnetBlock1D(
  157. dim=input_channel,
  158. dim_out=output_channel,
  159. time_emb_dim=time_embed_dim,
  160. )
  161. transformer_blocks = nn.ModuleList(
  162. [
  163. BasicTransformerBlock(
  164. dim=output_channel,
  165. num_attention_heads=num_heads,
  166. attention_head_dim=attention_head_dim,
  167. dropout=dropout,
  168. activation_fn=act_fn,
  169. )
  170. for _ in range(n_blocks)
  171. ]
  172. )
  173. upsample = (
  174. Upsample1D(output_channel, use_conv_transpose=True)
  175. if not is_last
  176. else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
  177. )
  178. self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
  179. self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
  180. self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
  181. self.initialize_weights()
  182. def initialize_weights(self):
  183. for m in self.modules():
  184. if isinstance(m, nn.Conv1d):
  185. nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
  186. if m.bias is not None:
  187. nn.init.constant_(m.bias, 0)
  188. elif isinstance(m, nn.GroupNorm):
  189. nn.init.constant_(m.weight, 1)
  190. nn.init.constant_(m.bias, 0)
  191. elif isinstance(m, nn.Linear):
  192. nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
  193. if m.bias is not None:
  194. nn.init.constant_(m.bias, 0)
  195. def forward(self, x, mask, mu, t, spks=None, cond=None):
  196. """Forward pass of the UNet1DConditional model.
  197. Args:
  198. x (torch.Tensor): shape (batch_size, in_channels, time)
  199. mask (_type_): shape (batch_size, 1, time)
  200. t (_type_): shape (batch_size)
  201. spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
  202. cond (_type_, optional): placeholder for future use. Defaults to None.
  203. Raises:
  204. ValueError: _description_
  205. ValueError: _description_
  206. Returns:
  207. _type_: _description_
  208. """
  209. t = self.time_embeddings(t).to(t.dtype)
  210. t = self.time_mlp(t)
  211. x = pack([x, mu], "b * t")[0]
  212. if spks is not None:
  213. spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
  214. x = pack([x, spks], "b * t")[0]
  215. if cond is not None:
  216. x = pack([x, cond], "b * t")[0]
  217. hiddens = []
  218. masks = [mask]
  219. for resnet, transformer_blocks, downsample in self.down_blocks:
  220. mask_down = masks[-1]
  221. x = resnet(x, mask_down, t)
  222. x = rearrange(x, "b c t -> b t c").contiguous()
  223. # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
  224. attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
  225. attn_mask = mask_to_bias(attn_mask==1, x.dtype)
  226. for transformer_block in transformer_blocks:
  227. x = transformer_block(
  228. hidden_states=x,
  229. attention_mask=attn_mask,
  230. timestep=t,
  231. )
  232. x = rearrange(x, "b t c -> b c t").contiguous()
  233. hiddens.append(x) # Save hidden states for skip connections
  234. x = downsample(x * mask_down)
  235. masks.append(mask_down[:, :, ::2])
  236. masks = masks[:-1]
  237. mask_mid = masks[-1]
  238. for resnet, transformer_blocks in self.mid_blocks:
  239. x = resnet(x, mask_mid, t)
  240. x = rearrange(x, "b c t -> b t c").contiguous()
  241. # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
  242. attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
  243. attn_mask = mask_to_bias(attn_mask==1, x.dtype)
  244. for transformer_block in transformer_blocks:
  245. x = transformer_block(
  246. hidden_states=x,
  247. attention_mask=attn_mask,
  248. timestep=t,
  249. )
  250. x = rearrange(x, "b t c -> b c t").contiguous()
  251. for resnet, transformer_blocks, upsample in self.up_blocks:
  252. mask_up = masks.pop()
  253. skip = hiddens.pop()
  254. x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
  255. x = resnet(x, mask_up, t)
  256. x = rearrange(x, "b c t -> b t c").contiguous()
  257. # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
  258. attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
  259. attn_mask = mask_to_bias(attn_mask==1, x.dtype)
  260. for transformer_block in transformer_blocks:
  261. x = transformer_block(
  262. hidden_states=x,
  263. attention_mask=attn_mask,
  264. timestep=t,
  265. )
  266. x = rearrange(x, "b t c -> b c t").contiguous()
  267. x = upsample(x * mask_up)
  268. x = self.final_block(x, mask_up)
  269. output = self.final_proj(x * mask_up)
  270. return output * mask