decoder.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Optional, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph


class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        if cache.size(2) == 0:
            x = F.pad(x, (self.causal_padding, 0), value=0.0)
        else:
            assert cache.size(2) == self.causal_padding
            x = torch.concat([cache, x], dim=2)
        cache = x[:, :, -self.causal_padding:]
        x = super(CausalConv1d, self).forward(x)
        return x, cache
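
# Note on CausalConv1d: the convolution left-pads by (kernel_size - 1) frames, so every
# output frame depends only on current and past inputs. In chunked inference the last
# (kernel_size - 1) input frames of each chunk are returned as `cache` and prepended to
# the next chunk, so streaming output matches a full-sequence forward pass.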


class CausalBlock1D(Block1D):
    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        output, cache = self.block[0](x * mask, cache)
        for i in range(1, len(self.block)):
            output = self.block[i](output)
        return output * mask, cache


class CausalResnetBlock1D(ResnetBlock1D):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
                block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        h, block1_cache = self.block1(x, mask, block1_cache)
        h += self.mlp(time_emb).unsqueeze(-1)
        h, block2_cache = self.block2(h, mask, block2_cache)
        output = h + self.res_conv(x * mask)
        return output, block1_cache, block2_cache
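
# CausalResnetBlock1D swaps the two Block1D sub-blocks for their causal counterparts and
# threads one convolution cache per sub-block through the forward call, so the residual
# branch remains usable in chunked, frame-by-frame decoding.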


class CausalAttnProcessor2_0(AttnProcessor2_0):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        super(CausalAttnProcessor2_0, self).__init__()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.Tensor]:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key_cache = attn.to_k(encoder_hidden_states)
        value_cache = attn.to_v(encoder_hidden_states)
        # NOTE always concat cache for interface compatibility
        key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
        value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
        cache = torch.stack([key_cache, value_cache], dim=3)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states, cache
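
# Note on the cache handled above: `cache` stacks the projected keys and values of one
# call along a trailing dimension (key at [..., 0], value at [..., 1]). Keys and values
# from the previous chunk are concatenated along the sequence dimension before attention,
# while the returned cache contains only the current chunk's projections.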


@maybe_allow_in_graph
class CausalAttention(Attention):
    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor2_0"] = None,
        out_dim: int = None,
    ):
        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
                                              cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
                                              spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
                                              _from_deprecated_attn_block, processor, out_dim)
        processor = CausalAttnProcessor2_0()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        **cross_attention_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            cache (`torch.Tensor`, *optional*):
                Key/value cache from the previous chunk; may be empty.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: The output of the attention layer and the updated key/value cache.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cache=cache,
            **cross_attention_kwargs,
        )
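
# CausalAttention only differs from diffusers' Attention in that it installs
# CausalAttnProcessor2_0 and passes the key/value `cache` through to it, returning the
# updated cache alongside the attention output.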


@maybe_allow_in_graph
class CausalBasicTransformerBlock(BasicTransformerBlock):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
                                                          cross_attention_dim, activation_fn, num_embeds_ada_norm,
                                                          attention_bias, only_cross_attention, double_self_attention,
                                                          upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
        self.attn1 = CausalAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output, cache = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            cache=cache,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states
        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states
        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: "
                    f"{self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = ff_output + hidden_states
        return hidden_states, cache
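
# In CausalBasicTransformerBlock the cache is threaded through the self-attention
# (attn1) path only; cross-attention and the feed-forward stage are unchanged from the
# parent BasicTransformerBlock.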


class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same shape as the target. So if your text content
        is shorter or longer than the output, please resample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]

        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
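
# ConditionalDecoder is a 1D U-Net over the time axis: every non-final down block halves
# the sequence length (the mask is subsampled with [:, :, ::2] to match), the mid blocks
# operate at the coarsest resolution, and each up block concatenates the matching skip
# connection before upsampling back. Masks are reapplied around every stage so padded
# frames do not leak into the output.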


class CausalConditionalDecoder(ConditionalDecoder):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        static_chunk_size=50,
        num_decoding_left_chunks=2,
    ):
        """
        This decoder requires an input with the same shape as the target. So if your text content
        is shorter or longer than the output, please resample it before feeding it to the decoder.
        """
        torch.nn.Module.__init__(self)
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))

        self.final_block = CausalBlock1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()
    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]

        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x, _, _ = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x, _ = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x, _ = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x, _, _ = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x, _ = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x, _, _ = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x, _ = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x, _ = upsample(x * mask_up)

        x, _ = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
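
    # The forward above mirrors ConditionalDecoder.forward but, when `streaming` is True,
    # builds chunked attention masks from `static_chunk_size` and `num_decoding_left_chunks`
    # so that each frame only attends to its own chunk and a limited number of past chunks.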
    @torch.inference_mode()
    def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
                      down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                      down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                      mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                      mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                      up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                      up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                      final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]
        down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
        mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
        up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)

        for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
            mask_down = masks[-1]
            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
                resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for i, transformer_block in enumerate(transformer_blocks):
                x, down_blocks_kv_cache_new[index, i] = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    cache=down_blocks_kv_cache[index, i],
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
            masks.append(mask_down[:, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]
        for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
                resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for i, transformer_block in enumerate(transformer_blocks):
                x, mid_blocks_kv_cache_new[index, i] = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    cache=mid_blocks_kv_cache[index, i],
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
                resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for i, transformer_block in enumerate(transformer_blocks):
                x, up_blocks_kv_cache_new[index, i] = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    cache=up_blocks_kv_cache[index, i],
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])

        x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
        output = self.final_proj(x * mask_up)
        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
            up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
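

# The guarded block below is a minimal smoke-test sketch and is not part of the original
# module. The channel sizes are hypothetical: it assumes `in_channels` equals the packed
# width of [x, mu, spks, cond] along the channel axis, and it uses act_fn="gelu" instead
# of the "snake" default. Adjust to the real CosyVoice configuration before relying on it.
if __name__ == "__main__":
    batch, feat_dim, length = 2, 80, 64
    decoder = ConditionalDecoder(
        in_channels=feat_dim * 4,   # x + mu + spks + cond, each feat_dim wide (assumption)
        out_channels=feat_dim,
        channels=(256, 256),
        act_fn="gelu",
    )
    x = torch.randn(batch, feat_dim, length)      # noisy target features
    mu = torch.randn(batch, feat_dim, length)     # conditioning mean from the upstream encoder
    spks = torch.randn(batch, feat_dim)           # speaker embedding, broadcast over time inside forward
    cond = torch.randn(batch, feat_dim, length)   # extra conditioning features
    mask = torch.ones(batch, 1, length)           # all frames valid
    t = torch.rand(batch)                         # flow/diffusion timestep per sample
    out = decoder(x, mask, mu, t, spks=spks, cond=cond)
    print(out.shape)                              # expected: (batch, feat_dim, length)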