decoder.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Tuple, Optional, Dict, Any
  15. import torch
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. from einops import pack, rearrange, repeat
  19. from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
  20. from cosyvoice.utils.common import mask_to_bias
  21. from cosyvoice.utils.mask import add_optional_chunk_mask
  22. from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
  23. from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
  24. class Transpose(torch.nn.Module):
  25. def __init__(self, dim0: int, dim1: int):
  26. super().__init__()
  27. self.dim0 = dim0
  28. self.dim1 = dim1
  29. def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
  30. x = torch.transpose(x, self.dim0, self.dim1)
  31. return x
  32. class CausalConv1d(torch.nn.Conv1d):
  33. def __init__(
  34. self,
  35. in_channels: int,
  36. out_channels: int,
  37. kernel_size: int,
  38. stride: int = 1,
  39. dilation: int = 1,
  40. groups: int = 1,
  41. bias: bool = True,
  42. padding_mode: str = 'zeros',
  43. device=None,
  44. dtype=None
  45. ) -> None:
  46. super(CausalConv1d, self).__init__(in_channels, out_channels,
  47. kernel_size, stride,
  48. padding=0, dilation=dilation,
  49. groups=groups, bias=bias,
  50. padding_mode=padding_mode,
  51. device=device, dtype=dtype)
  52. assert stride == 1
  53. self.causal_padding = kernel_size - 1
  54. def forward(self, x: torch.Tensor, cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
  55. if cache.size(2) == 0:
  56. x = F.pad(x, (self.causal_padding, 0), value=0.0)
  57. else:
  58. assert cache.size(2) == self.causal_padding
  59. x = torch.concat([cache, x], dim=2)
  60. cache = x[:, :, -self.causal_padding:]
  61. x = super(CausalConv1d, self).forward(x)
  62. return x, cache
  63. class CausalBlock1D(Block1D):
  64. def __init__(self, dim: int, dim_out: int):
  65. super(CausalBlock1D, self).__init__(dim, dim_out)
  66. self.block = torch.nn.Sequential(
  67. CausalConv1d(dim, dim_out, 3),
  68. Transpose(1, 2),
  69. nn.LayerNorm(dim_out),
  70. Transpose(1, 2),
  71. nn.Mish(),
  72. )
  73. def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
  74. output, cache = self.block[0](x * mask, cache)
  75. for i in range(1, len(self.block)):
  76. output = self.block[i](output)
  77. return output * mask, cache
  78. class CausalResnetBlock1D(ResnetBlock1D):
  79. def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
  80. super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
  81. self.block1 = CausalBlock1D(dim, dim_out)
  82. self.block2 = CausalBlock1D(dim_out, dim_out)
  83. def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor, block1_cache: torch.Tensor=torch.zeros(0, 0, 0), block2_cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  84. h, block1_cache = self.block1(x, mask, block1_cache)
  85. h += self.mlp(time_emb).unsqueeze(-1)
  86. h, block2_cache = self.block2(h, mask, block2_cache)
  87. output = h + self.res_conv(x * mask)
  88. return output, block1_cache, block2_cache
class CausalAttnProcessor2_0(AttnProcessor2_0):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Variant of diffusers' ``AttnProcessor2_0`` that also accepts a key/value ``cache`` and returns
    ``(hidden_states, cache)``. Keys and values are stacked on a trailing axis, with
    ``cache[..., 0]`` holding keys and ``cache[..., 1]`` holding values. Cached entries are
    prepended to the current projections along axis 1, the sequence axis of the
    ``(batch, seq, inner_dim)`` projections.
    """

    def __init__(self):
        super(CausalAttnProcessor2_0, self).__init__()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        *args,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.Tensor]:
        """Run scaled dot-product attention for ``attn``, optionally prepending cached keys/values.

        Args:
            attn: the ``Attention`` module supplying projections, norms, head count and output layers.
            hidden_states: query-side input, either 3-D ``(batch, seq, dim)`` or 4-D
                ``(batch, channel, height, width)``. The 4-D form is flattened to a sequence and
                restored at the end.
            encoder_hidden_states: key/value-side input; defaults to ``hidden_states`` (self-attention).
            attention_mask: used as given, because ``attn.prepare_attention_mask`` is intentionally
                not called. A head axis is inserted at dim 1 and repeated ``attn.heads`` times before SDPA.
                NOTE(review): with a non-empty cache, the mask's key axis presumably has to cover
                cached + current keys; confirm with the caller.
            temb: forwarded to ``attn.spatial_norm`` when that norm exists.
            cache: previous keys/values stacked on the last axis. When ``cache.size(0) == 0``
                (the default), no history is used.

        Returns:
            ``(hidden_states, cache)``. The returned ``cache`` contains only the key/value
            projections computed in this call; the incoming cache is not concatenated into it.
        """
        # Positional extras or `scale=` are accepted for diffusers API compatibility but ignored.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a sequence: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        # Projections of the current input only; these become the outgoing cache below.
        key_cache = attn.to_k(encoder_hidden_states)
        value_cache = attn.to_v(encoder_hidden_states)
        # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
        # (i.e. an "initial" cache may have zero cached steps on axis 1 while still being a real
        # tensor; concatenating a zero-length cache is harmless).
        if cache.size(0) != 0:
            key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
            value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
        else:
            key, value = key_cache, value_cache
        # Outgoing cache: (..., 0) = keys, (..., 1) = values of THIS call only.
        cache = torch.stack([key_cache, value_cache], dim=3)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        # Split heads: (B, S, heads*head_dim) -> (B, heads, S, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # Merge heads back: (B, heads, L, head_dim) -> (B, L, heads*head_dim).
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states, cache
  162. @maybe_allow_in_graph
  163. class CausalAttention(Attention):
  164. def __init__(
  165. self,
  166. query_dim: int,
  167. cross_attention_dim: Optional[int] = None,
  168. heads: int = 8,
  169. dim_head: int = 64,
  170. dropout: float = 0.0,
  171. bias: bool = False,
  172. upcast_attention: bool = False,
  173. upcast_softmax: bool = False,
  174. cross_attention_norm: Optional[str] = None,
  175. cross_attention_norm_num_groups: int = 32,
  176. qk_norm: Optional[str] = None,
  177. added_kv_proj_dim: Optional[int] = None,
  178. norm_num_groups: Optional[int] = None,
  179. spatial_norm_dim: Optional[int] = None,
  180. out_bias: bool = True,
  181. scale_qk: bool = True,
  182. only_cross_attention: bool = False,
  183. eps: float = 1e-5,
  184. rescale_output_factor: float = 1.0,
  185. residual_connection: bool = False,
  186. _from_deprecated_attn_block: bool = False,
  187. processor: Optional["AttnProcessor2_0"] = None,
  188. out_dim: int = None,
  189. ):
  190. super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, qk_norm,
  191. added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim)
  192. processor = CausalAttnProcessor2_0()
  193. self.set_processor(processor)
  194. def forward(
  195. self,
  196. hidden_states: torch.FloatTensor,
  197. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  198. attention_mask: Optional[torch.FloatTensor] = None,
  199. cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
  200. **cross_attention_kwargs,
  201. ) -> Tuple[torch.Tensor, torch.Tensor]:
  202. r"""
  203. The forward method of the `Attention` class.
  204. Args:
  205. hidden_states (`torch.Tensor`):
  206. The hidden states of the query.
  207. encoder_hidden_states (`torch.Tensor`, *optional*):
  208. The hidden states of the encoder.
  209. attention_mask (`torch.Tensor`, *optional*):
  210. The attention mask to use. If `None`, no mask is applied.
  211. **cross_attention_kwargs:
  212. Additional keyword arguments to pass along to the cross attention.
  213. Returns:
  214. `torch.Tensor`: The output of the attention layer.
  215. """
  216. # The `Attention` class can call different attention processors / attention functions
  217. # here we simply pass along all tensors to the selected processor class
  218. # For standard processors that are defined here, `**cross_attention_kwargs` is empty
  219. attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
  220. unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
  221. if len(unused_kwargs) > 0:
  222. logger.warning(
  223. f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
  224. )
  225. cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
  226. return self.processor(
  227. self,
  228. hidden_states,
  229. encoder_hidden_states=encoder_hidden_states,
  230. attention_mask=attention_mask,
  231. cache=cache,
  232. **cross_attention_kwargs,
  233. )
@maybe_allow_in_graph
class CausalBasicTransformerBlock(BasicTransformerBlock):
    """matcha ``BasicTransformerBlock`` whose first attention (``attn1``) is a ``CausalAttention``.

    ``forward`` threads a key/value ``cache`` through ``attn1`` and returns
    ``(hidden_states, cache)`` instead of a single tensor. ``attn2`` (cross-attention) and ``ff``
    are the parent's modules and are called without the cache.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        # NOTE(review): positional forwarding relies on matcha's BasicTransformerBlock keeping
        # this exact parameter order.
        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout, cross_attention_dim, activation_fn, num_embeds_ada_norm,
                                                          attention_bias, only_cross_attention, double_self_attention, upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
        # Replace the parent's attn1 with the cache-aware causal attention. It is cross-attention
        # only when `only_cross_attention` is set; otherwise self-attention.
        self.attn1 = CausalAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attn1 (with KV cache), optional cross-attention (attn2), then the feed-forward.

        Each sub-layer is pre-normalized and its output is added back residually.

        Args:
            hidden_states: input sequence. Presumably (batch, time, dim); the decoders in this file
                call the block after rearranging ``b c t -> b t c``.
            attention_mask: mask for ``attn1`` when it is self-attention.
            encoder_hidden_states: keys/values for ``attn2``, and for ``attn1`` when ``only_cross_attention``.
            encoder_attention_mask: mask for ``attn2``, and for ``attn1`` when ``only_cross_attention``.
            timestep: passed to the ada-norm variants of ``norm1`` / ``norm2``.
            cross_attention_kwargs: extra keyword arguments forwarded to ``attn1`` and ``attn2``.
            class_labels: passed to ``norm1`` when ``use_ada_layer_norm_zero``.
            cache: key/value cache for ``attn1``.

        Returns:
            ``(hidden_states, cache)``, where ``cache`` is the one returned by ``attn1``.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # ada-norm-zero also yields the gates/shift/scale used later for attn1 and the FF.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        # attn1 is a CausalAttention, so it returns (output, updated cache).
        attn_output, cache = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            cache=cache,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states
        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            # attn2 is the parent's attention module; it is called without a cache.
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states
        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = ff_output + hidden_states
        return hidden_states, cache
  329. class ConditionalDecoder(nn.Module):
  330. def __init__(
  331. self,
  332. in_channels,
  333. out_channels,
  334. channels=(256, 256),
  335. dropout=0.05,
  336. attention_head_dim=64,
  337. n_blocks=1,
  338. num_mid_blocks=2,
  339. num_heads=4,
  340. act_fn="snake",
  341. ):
  342. """
  343. This decoder requires an input with the same shape of the target. So, if your text content
  344. is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
  345. """
  346. super().__init__()
  347. channels = tuple(channels)
  348. self.in_channels = in_channels
  349. self.out_channels = out_channels
  350. self.time_embeddings = SinusoidalPosEmb(in_channels)
  351. time_embed_dim = channels[0] * 4
  352. self.time_mlp = TimestepEmbedding(
  353. in_channels=in_channels,
  354. time_embed_dim=time_embed_dim,
  355. act_fn="silu",
  356. )
  357. self.down_blocks = nn.ModuleList([])
  358. self.mid_blocks = nn.ModuleList([])
  359. self.up_blocks = nn.ModuleList([])
  360. output_channel = in_channels
  361. for i in range(len(channels)): # pylint: disable=consider-using-enumerate
  362. input_channel = output_channel
  363. output_channel = channels[i]
  364. is_last = i == len(channels) - 1
  365. resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
  366. transformer_blocks = nn.ModuleList(
  367. [
  368. BasicTransformerBlock(
  369. dim=output_channel,
  370. num_attention_heads=num_heads,
  371. attention_head_dim=attention_head_dim,
  372. dropout=dropout,
  373. activation_fn=act_fn,
  374. )
  375. for _ in range(n_blocks)
  376. ]
  377. )
  378. downsample = (
  379. Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
  380. )
  381. self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
  382. for _ in range(num_mid_blocks):
  383. input_channel = channels[-1]
  384. out_channels = channels[-1]
  385. resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
  386. transformer_blocks = nn.ModuleList(
  387. [
  388. BasicTransformerBlock(
  389. dim=output_channel,
  390. num_attention_heads=num_heads,
  391. attention_head_dim=attention_head_dim,
  392. dropout=dropout,
  393. activation_fn=act_fn,
  394. )
  395. for _ in range(n_blocks)
  396. ]
  397. )
  398. self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
  399. channels = channels[::-1] + (channels[0],)
  400. for i in range(len(channels) - 1):
  401. input_channel = channels[i] * 2
  402. output_channel = channels[i + 1]
  403. is_last = i == len(channels) - 2
  404. resnet = ResnetBlock1D(
  405. dim=input_channel,
  406. dim_out=output_channel,
  407. time_emb_dim=time_embed_dim,
  408. )
  409. transformer_blocks = nn.ModuleList(
  410. [
  411. BasicTransformerBlock(
  412. dim=output_channel,
  413. num_attention_heads=num_heads,
  414. attention_head_dim=attention_head_dim,
  415. dropout=dropout,
  416. activation_fn=act_fn,
  417. )
  418. for _ in range(n_blocks)
  419. ]
  420. )
  421. upsample = (
  422. Upsample1D(output_channel, use_conv_transpose=True)
  423. if not is_last
  424. else nn.Conv1d(output_channel, output_channel, 3, padding=1)
  425. )
  426. self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
  427. self.final_block = Block1D(channels[-1], channels[-1])
  428. self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
  429. self.initialize_weights()
  430. def initialize_weights(self):
  431. for m in self.modules():
  432. if isinstance(m, nn.Conv1d):
  433. nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
  434. if m.bias is not None:
  435. nn.init.constant_(m.bias, 0)
  436. elif isinstance(m, nn.GroupNorm):
  437. nn.init.constant_(m.weight, 1)
  438. nn.init.constant_(m.bias, 0)
  439. elif isinstance(m, nn.Linear):
  440. nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
  441. if m.bias is not None:
  442. nn.init.constant_(m.bias, 0)
  443. def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
  444. """Forward pass of the UNet1DConditional model.
  445. Args:
  446. x (torch.Tensor): shape (batch_size, in_channels, time)
  447. mask (_type_): shape (batch_size, 1, time)
  448. t (_type_): shape (batch_size)
  449. spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
  450. cond (_type_, optional): placeholder for future use. Defaults to None.
  451. Raises:
  452. ValueError: _description_
  453. ValueError: _description_
  454. Returns:
  455. _type_: _description_
  456. """
  457. t = self.time_embeddings(t).to(t.dtype)
  458. t = self.time_mlp(t)
  459. x = pack([x, mu], "b * t")[0]
  460. if spks is not None:
  461. spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
  462. x = pack([x, spks], "b * t")[0]
  463. if cond is not None:
  464. x = pack([x, cond], "b * t")[0]
  465. hiddens = []
  466. masks = [mask]
  467. for resnet, transformer_blocks, downsample in self.down_blocks:
  468. mask_down = masks[-1]
  469. x = resnet(x, mask_down, t)
  470. x = rearrange(x, "b c t -> b t c").contiguous()
  471. attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
  472. attn_mask = mask_to_bias(attn_mask, x.dtype)
  473. for transformer_block in transformer_blocks:
  474. x = transformer_block(
  475. hidden_states=x,
  476. attention_mask=attn_mask,
  477. timestep=t,
  478. )
  479. x = rearrange(x, "b t c -> b c t").contiguous()
  480. hiddens.append(x) # Save hidden states for skip connections
  481. x = downsample(x * mask_down)
  482. masks.append(mask_down[:, :, ::2])
  483. masks = masks[:-1]
  484. mask_mid = masks[-1]
  485. for resnet, transformer_blocks in self.mid_blocks:
  486. x = resnet(x, mask_mid, t)
  487. x = rearrange(x, "b c t -> b t c").contiguous()
  488. attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
  489. attn_mask = mask_to_bias(attn_mask, x.dtype)
  490. for transformer_block in transformer_blocks:
  491. x = transformer_block(
  492. hidden_states=x,
  493. attention_mask=attn_mask,
  494. timestep=t,
  495. )
  496. x = rearrange(x, "b t c -> b c t").contiguous()
  497. for resnet, transformer_blocks, upsample in self.up_blocks:
  498. mask_up = masks.pop()
  499. skip = hiddens.pop()
  500. x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
  501. x = resnet(x, mask_up, t)
  502. x = rearrange(x, "b c t -> b t c").contiguous()
  503. attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
  504. attn_mask = mask_to_bias(attn_mask, x.dtype)
  505. for transformer_block in transformer_blocks:
  506. x = transformer_block(
  507. hidden_states=x,
  508. attention_mask=attn_mask,
  509. timestep=t,
  510. )
  511. x = rearrange(x, "b t c -> b c t").contiguous()
  512. x = upsample(x * mask_up)
  513. x = self.final_block(x, mask_up)
  514. output = self.final_proj(x * mask_up)
  515. return output * mask
  516. class CausalConditionalDecoder(ConditionalDecoder):
  517. def __init__(
  518. self,
  519. in_channels,
  520. out_channels,
  521. channels=(256, 256),
  522. dropout=0.05,
  523. attention_head_dim=64,
  524. n_blocks=1,
  525. num_mid_blocks=2,
  526. num_heads=4,
  527. act_fn="snake",
  528. static_chunk_size=50,
  529. num_decoding_left_chunks=2,
  530. ):
  531. """
  532. This decoder requires an input with the same shape of the target. So, if your text content
  533. is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
  534. """
  535. torch.nn.Module.__init__(self)
  536. channels = tuple(channels)
  537. self.in_channels = in_channels
  538. self.out_channels = out_channels
  539. self.time_embeddings = SinusoidalPosEmb(in_channels)
  540. time_embed_dim = channels[0] * 4
  541. self.time_mlp = TimestepEmbedding(
  542. in_channels=in_channels,
  543. time_embed_dim=time_embed_dim,
  544. act_fn="silu",
  545. )
  546. self.static_chunk_size = static_chunk_size
  547. self.num_decoding_left_chunks = num_decoding_left_chunks
  548. self.down_blocks = nn.ModuleList([])
  549. self.mid_blocks = nn.ModuleList([])
  550. self.up_blocks = nn.ModuleList([])
  551. output_channel = in_channels
  552. for i in range(len(channels)): # pylint: disable=consider-using-enumerate
  553. input_channel = output_channel
  554. output_channel = channels[i]
  555. is_last = i == len(channels) - 1
  556. resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
  557. transformer_blocks = nn.ModuleList(
  558. [
  559. CausalBasicTransformerBlock(
  560. dim=output_channel,
  561. num_attention_heads=num_heads,
  562. attention_head_dim=attention_head_dim,
  563. dropout=dropout,
  564. activation_fn=act_fn,
  565. )
  566. for _ in range(n_blocks)
  567. ]
  568. )
  569. downsample = (
  570. Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
  571. )
  572. self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
  573. for _ in range(num_mid_blocks):
  574. input_channel = channels[-1]
  575. out_channels = channels[-1]
  576. resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
  577. transformer_blocks = nn.ModuleList(
  578. [
  579. CausalBasicTransformerBlock(
  580. dim=output_channel,
  581. num_attention_heads=num_heads,
  582. attention_head_dim=attention_head_dim,
  583. dropout=dropout,
  584. activation_fn=act_fn,
  585. )
  586. for _ in range(n_blocks)
  587. ]
  588. )
  589. self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
  590. channels = channels[::-1] + (channels[0],)
  591. for i in range(len(channels) - 1):
  592. input_channel = channels[i] * 2
  593. output_channel = channels[i + 1]
  594. is_last = i == len(channels) - 2
  595. resnet = CausalResnetBlock1D(
  596. dim=input_channel,
  597. dim_out=output_channel,
  598. time_emb_dim=time_embed_dim,
  599. )
  600. transformer_blocks = nn.ModuleList(
  601. [
  602. CausalBasicTransformerBlock(
  603. dim=output_channel,
  604. num_attention_heads=num_heads,
  605. attention_head_dim=attention_head_dim,
  606. dropout=dropout,
  607. activation_fn=act_fn,
  608. )
  609. for _ in range(n_blocks)
  610. ]
  611. )
  612. upsample = (
  613. Upsample1D(output_channel, use_conv_transpose=True)
  614. if not is_last
  615. else CausalConv1d(output_channel, output_channel, 3)
  616. )
  617. self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
  618. self.final_block = CausalBlock1D(channels[-1], channels[-1])
  619. self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
  620. self.initialize_weights()
  621. def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
  622. """Forward pass of the UNet1DConditional model.
  623. Args:
  624. x (torch.Tensor): shape (batch_size, in_channels, time)
  625. mask (_type_): shape (batch_size, 1, time)
  626. t (_type_): shape (batch_size)
  627. spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
  628. cond (_type_, optional): placeholder for future use. Defaults to None.
  629. Raises:
  630. ValueError: _description_
  631. ValueError: _description_
  632. Returns:
  633. _type_: _description_
  634. """
  635. t = self.time_embeddings(t).to(t.dtype)
  636. t = self.time_mlp(t)
  637. x = pack([x, mu], "b * t")[0]
  638. if spks is not None:
  639. spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
  640. x = pack([x, spks], "b * t")[0]
  641. if cond is not None:
  642. x = pack([x, cond], "b * t")[0]
  643. hiddens = []
  644. masks = [mask]
  645. for resnet, transformer_blocks, downsample in self.down_blocks:
  646. mask_down = masks[-1]
  647. x, _, _ = resnet(x, mask_down, t)
  648. x = rearrange(x, "b c t -> b t c").contiguous()
  649. if streaming is True:
  650. attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
  651. else:
  652. attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
  653. attn_mask = mask_to_bias(attn_mask, x.dtype)
  654. for transformer_block in transformer_blocks:
  655. x, _ = transformer_block(
  656. hidden_states=x,
  657. attention_mask=attn_mask,
  658. timestep=t,
  659. )
  660. x = rearrange(x, "b t c -> b c t").contiguous()
  661. hiddens.append(x) # Save hidden states for skip connections
  662. x, _ = downsample(x * mask_down)
  663. masks.append(mask_down[:, :, ::2])
  664. masks = masks[:-1]
  665. mask_mid = masks[-1]
  666. for resnet, transformer_blocks in self.mid_blocks:
  667. x, _, _ = resnet(x, mask_mid, t)
  668. x = rearrange(x, "b c t -> b t c").contiguous()
  669. if streaming is True:
  670. attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
  671. else:
  672. attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
  673. attn_mask = mask_to_bias(attn_mask, x.dtype)
  674. for transformer_block in transformer_blocks:
  675. x, _ = transformer_block(
  676. hidden_states=x,
  677. attention_mask=attn_mask,
  678. timestep=t,
  679. )
  680. x = rearrange(x, "b t c -> b c t").contiguous()
  681. for resnet, transformer_blocks, upsample in self.up_blocks:
  682. mask_up = masks.pop()
  683. skip = hiddens.pop()
  684. x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
  685. x, _, _ = resnet(x, mask_up, t)
  686. x = rearrange(x, "b c t -> b t c").contiguous()
  687. if streaming is True:
  688. attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
  689. else:
  690. attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
  691. attn_mask = mask_to_bias(attn_mask, x.dtype)
  692. for transformer_block in transformer_blocks:
  693. x, _ = transformer_block(
  694. hidden_states=x,
  695. attention_mask=attn_mask,
  696. timestep=t,
  697. )
  698. x = rearrange(x, "b t c -> b c t").contiguous()
  699. x, _ = upsample(x * mask_up)
  700. x, _ = self.final_block(x, mask_up)
  701. output = self.final_proj(x * mask_up)
  702. return output * mask
  def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
                    down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                    down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                    mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                    mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                    up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
                    up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
                    final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
                    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
      """Run one streaming chunk through the decoder, threading per-layer caches.

      Mirrors ``forward`` (same down / mid / up / final modules in the same
      order), but every causal conv and every transformer block is fed the
      state saved from the previous chunk and yields the state for the next.

      Args:
          x (torch.Tensor): shape (batch_size, in_channels, time).
          mask (torch.Tensor): shape (batch_size, 1, time). Multiplies the
              activations and the output; it is NOT applied inside attention
              (every query attends to all cached and current keys).
          mu (torch.Tensor): concatenated to ``x`` along the channel axis.
          t (torch.Tensor): shape (batch_size); embedded by ``time_embeddings``
              followed by ``time_mlp``.
          spks (torch.Tensor, optional): shape (batch_size, condition_channels);
              repeated over time and concatenated along channels.
          cond (torch.Tensor, optional): concatenated along channels if given.
          down_blocks_conv_cache, mid_blocks_conv_cache, up_blocks_conv_cache:
              indexed by stage. Channel slices of each stage are passed to the
              resnet / down- / up-sample convs and OVERWRITTEN IN PLACE with the
              states they return, so the caller's tensors are mutated.
          down_blocks_kv_cache, mid_blocks_kv_cache, up_blocks_kv_cache:
              indexed ``[stage, transformer_block]``; dim 3 is the number of
              extra (cached) key positions each query may attend to. Read only;
              fresh caches are returned instead.
          final_blocks_conv_cache: state for ``final_block``; replaced by the
              value ``final_block`` returns.

      Note:
          The empty default caches cannot be indexed (``cache[0]`` on a size-0
          dim raises IndexError), so every cache must be supplied explicitly.

      Returns:
          Tuple of 8 tensors: ``output * mask``, the (in-place updated) down
          conv cache, the new down kv cache, the mid conv cache, the new mid kv
          cache, the up conv cache, the new up kv cache and the new final conv
          cache. Each new kv cache has the incoming cache's layout with dim 3
          set to this chunk's length, and ``x``'s dtype and device.
      """

      def _fresh_kv_cache(cache: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
          # Zero cache shaped like `cache` except the key-time axis (dim 3),
          # which holds only this chunk's frames. `new_zeros` allocates directly
          # with `like`'s dtype/device (previously float32 on CPU, then copied),
          # and deriving the shape avoids hard-coding the layer counts.
          return like.new_zeros((*cache.shape[:3], like.size(2), *cache.shape[4:]))

      def _chunk_attn_bias(h: torch.Tensor, kv_cache: torch.Tensor) -> torch.Tensor:
          # All-True mask: each of the h.size(1) queries sees every cached key
          # plus every key of the current chunk.
          keep = torch.ones(h.size(0), h.size(1), h.size(1) + kv_cache.size(3),
                            dtype=torch.bool, device=h.device)
          return mask_to_bias(keep, h.dtype)

      t = self.time_embeddings(t).to(t.dtype)
      t = self.time_mlp(t)
      x = pack([x, mu], "b * t")[0]
      if spks is not None:
          spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
          x = pack([x, spks], "b * t")[0]
      if cond is not None:
          x = pack([x, cond], "b * t")[0]
      hiddens = []
      masks = [mask]

      # Allocate before the down path so the time axis is the chunk length.
      down_blocks_kv_cache_new = _fresh_kv_cache(down_blocks_kv_cache, x)
      mid_blocks_kv_cache_new = _fresh_kv_cache(mid_blocks_kv_cache, x)
      up_blocks_kv_cache_new = _fresh_kv_cache(up_blocks_kv_cache, x)

      for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
          mask_down = masks[-1]
          # View into the caller's tensor: slice assignments below write through.
          # NOTE(review): channel offsets 320 / 576 are hard-coded; presumably
          # they match this config's conv channel counts -- confirm.
          stage_cache = down_blocks_conv_cache[index]
          x, stage_cache[:, :320], stage_cache[:, 320: 576] = resnet(
              x, mask_down, t, stage_cache[:, :320], stage_cache[:, 320: 576])
          x = rearrange(x, "b c t -> b t c").contiguous()
          attn_mask = _chunk_attn_bias(x, down_blocks_kv_cache)
          for i, transformer_block in enumerate(transformer_blocks):
              x, down_blocks_kv_cache_new[index, i] = transformer_block(
                  hidden_states=x,
                  attention_mask=attn_mask,
                  timestep=t,
                  cache=down_blocks_kv_cache[index, i],
              )
          x = rearrange(x, "b t c -> b c t").contiguous()
          hiddens.append(x)  # Save hidden states for skip connections
          x, stage_cache[:, 576:] = downsample(x * mask_down, stage_cache[:, 576:])
          masks.append(mask_down[:, :, ::2])
      masks = masks[:-1]
      mask_mid = masks[-1]

      for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
          # NOTE(review): split at channel 256 is hard-coded -- confirm per config.
          stage_cache = mid_blocks_conv_cache[index]
          x, stage_cache[:, :256], stage_cache[:, 256:] = resnet(
              x, mask_mid, t, stage_cache[:, :256], stage_cache[:, 256:])
          x = rearrange(x, "b c t -> b t c").contiguous()
          attn_mask = _chunk_attn_bias(x, mid_blocks_kv_cache)
          for i, transformer_block in enumerate(transformer_blocks):
              x, mid_blocks_kv_cache_new[index, i] = transformer_block(
                  hidden_states=x,
                  attention_mask=attn_mask,
                  timestep=t,
                  cache=mid_blocks_kv_cache[index, i]
              )
          x = rearrange(x, "b t c -> b c t").contiguous()

      for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
          mask_up = masks.pop()
          skip = hiddens.pop()
          x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
          # NOTE(review): channel offsets 512 / 768 are hard-coded -- confirm per config.
          stage_cache = up_blocks_conv_cache[index]
          x, stage_cache[:, :512], stage_cache[:, 512: 768] = resnet(
              x, mask_up, t, stage_cache[:, :512], stage_cache[:, 512: 768])
          x = rearrange(x, "b c t -> b t c").contiguous()
          attn_mask = _chunk_attn_bias(x, up_blocks_kv_cache)
          for i, transformer_block in enumerate(transformer_blocks):
              x, up_blocks_kv_cache_new[index, i] = transformer_block(
                  hidden_states=x,
                  attention_mask=attn_mask,
                  timestep=t,
                  cache=up_blocks_kv_cache[index, i]
              )
          x = rearrange(x, "b t c -> b c t").contiguous()
          x, stage_cache[:, 768:] = upsample(x * mask_up, stage_cache[:, 768:])

      x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
      output = self.final_proj(x * mask_up)
      return (output * mask,
              down_blocks_conv_cache, down_blocks_kv_cache_new,
              mid_blocks_conv_cache, mid_blocks_kv_cache_new,
              up_blocks_conv_cache, up_blocks_kv_cache_new,
              final_blocks_conv_cache)
  786. x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
  787. x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
  788. output = self.final_proj(x * mask_up)
  789. return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache