# decoder.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Optional, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph


class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        if cache.size(2) == 0:
            x = F.pad(x, (self.causal_padding, 0), value=0.0)
        else:
            assert cache.size(2) == self.causal_padding
            x = torch.concat([cache, x], dim=2)
        cache = x[:, :, -self.causal_padding:]
        x = super(CausalConv1d, self).forward(x)
        return x, cache
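
# NOTE: CausalConv1d makes a plain Conv1d causal by left-padding the input with
# `kernel_size - 1` frames (zeros on the first chunk, or the cached tail of the
# previous chunk), so every output frame depends only on current and past inputs.
# Feeding chunks sequentially and threading the returned cache back in reproduces
# the full-sequence causal output, which is what the chunked inference path relies on.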


class CausalBlock1D(Block1D):
    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
        output, cache = self.block[0](x * mask, cache)
        for i in range(1, len(self.block)):
            output = self.block[i](output)
        return output * mask, cache


class CausalResnetBlock1D(ResnetBlock1D):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
                block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        h, block1_cache = self.block1(x, mask, block1_cache)
        h += self.mlp(time_emb).unsqueeze(-1)
        h, block2_cache = self.block2(h, mask, block2_cache)
        output = h + self.res_conv(x * mask)
        return output, block1_cache, block2_cache


class CausalAttnProcessor2_0(AttnProcessor2_0):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        super(CausalAttnProcessor2_0, self).__init__()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.Tensor]:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key_cache = attn.to_k(encoder_hidden_states)
        value_cache = attn.to_v(encoder_hidden_states)
        # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
        if cache.size(0) != 0:
            key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
            value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
        else:
            key, value = key_cache, value_cache
        cache = torch.stack([key_cache, value_cache], dim=3)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states, cache
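
# NOTE: the cache returned above stacks this call's key/value projections along a
# trailing dimension, giving shape (batch, seq_len, inner_dim, 2). On the next chunk
# it is split back into key and value and concatenated along the sequence axis before
# the heads are split, so attention can look back over previously processed frames.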


@maybe_allow_in_graph
class CausalAttention(Attention):
    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor2_0"] = None,
        out_dim: int = None,
    ):
        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
                                              cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
                                              spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
                                              _from_deprecated_attn_block, processor, out_dim)
        processor = CausalAttnProcessor2_0()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        **cross_attention_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cache=cache,
            **cross_attention_kwargs,
        )


@maybe_allow_in_graph
class CausalBasicTransformerBlock(BasicTransformerBlock):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
                                                          cross_attention_dim, activation_fn, num_embeds_ada_norm,
                                                          attention_bias, only_cross_attention, double_self_attention,
                                                          upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
        self.attn1 = CausalAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output, cache = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            cache=cache,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states
        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states
        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
{self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = ff_output + hidden_states
        return hidden_states, cache


class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same shape as the target. So, if your text content
        is shorter or longer than the outputs, please re-sample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): conditioning features packed onto the input channels. Defaults to None.

        Returns:
            torch.Tensor: shape (batch_size, out_channels, time)
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask


class CausalConditionalDecoder(ConditionalDecoder):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        static_chunk_size=50,
        num_decoding_left_chunks=2,
    ):
        """
        This decoder requires an input with the same shape as the target. So, if your text content
        is shorter or longer than the outputs, please re-sample it before feeding it to the decoder.
        """
        torch.nn.Module.__init__(self)
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    CausalBasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): conditioning features packed onto the input channels. Defaults to None.
            streaming (bool): if True, use a chunk-wise causal attention mask. Defaults to False.

        Returns:
            torch.Tensor: shape (batch_size, out_channels, time)
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x, _, _ = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x, _ = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x, _ = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x, _, _ = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x, _ = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x, _, _ = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x, _ = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x, _ = upsample(x * mask_up)
        x, _ = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
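
    # NOTE: forward() above covers training and full-utterance inference (streaming=True only swaps
    # in a chunk-wise attention mask), while forward_chunk() below runs true chunked inference:
    # each call consumes one chunk plus the convolution and KV caches produced by the previous call,
    # and attends over all cached frames with a full (non-chunked) mask. The hard-coded cache
    # channel splits (e.g. 320/576, 512/768) and KV-cache shapes appear to assume the production
    # CosyVoice configuration rather than the constructor defaults.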

    def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
                      down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                      down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                      mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                      mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                      up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                      up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                      final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Chunked forward pass of the UNet1DConditional model using convolution and KV caches.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            t (torch.Tensor): shape (batch_size)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): conditioning features packed onto the input channels. Defaults to None.

        Returns:
            Tuple[torch.Tensor, ...]: the output of shape (batch_size, out_channels, time), followed by the updated caches.
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]
        down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
        mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
        up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
        for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
            mask_down = masks[-1]
            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
                resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for i, transformer_block in enumerate(transformer_blocks):
                x, down_blocks_kv_cache_new[index, i] = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    cache=down_blocks_kv_cache[index, i],
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
                resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for i, transformer_block in enumerate(transformer_blocks):
                x, mid_blocks_kv_cache_new[index, i] = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    cache=mid_blocks_kv_cache[index, i]
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
        for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
                resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for i, transformer_block in enumerate(transformer_blocks):
                x, up_blocks_kv_cache_new[index, i] = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                    cache=up_blocks_kv_cache[index, i]
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
        x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
        output = self.final_proj(x * mask_up)
        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
            up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
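

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module. The shapes follow the usual
    # CosyVoice flow-matching setup (80 mel bins for x/mu/cond plus an 80-dim speaker
    # embedding, so the packed input has 320 channels); channels=(256,) and act_fn="gelu"
    # are assumptions chosen so the causal forward path runs end to end with the pinned
    # diffusers/matcha dependencies from the CosyVoice setup.
    batch_size, mel_dim, length = 2, 80, 64
    decoder = CausalConditionalDecoder(
        in_channels=320,
        out_channels=mel_dim,
        channels=(256,),
        act_fn="gelu",
    )
    x = torch.randn(batch_size, mel_dim, length)       # noisy sample
    mu = torch.randn(batch_size, mel_dim, length)      # encoder output
    spks = torch.randn(batch_size, mel_dim)            # speaker embedding
    cond = torch.randn(batch_size, mel_dim, length)    # extra conditioning
    mask = torch.ones(batch_size, 1, length)
    t = torch.rand(batch_size)                         # flow-matching timestep
    with torch.no_grad():
        y = decoder(x, mask, mu, t, spks=spks, cond=cond, streaming=True)
    print(y.shape)  # expected: torch.Size([2, 80, 64])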