modules.py

  1. """
  2. ein notation:
  3. b - batch
  4. n - sequence
  5. nt - text sequence
  6. nw - raw wave length
  7. d - dimension
  8. """
  9. from __future__ import annotations
  10. from typing import Optional
  11. import math
  12. import torch
  13. from torch import nn
  14. import torch.nn.functional as F
  15. import torchaudio
  16. from x_transformers.x_transformers import apply_rotary_pos_emb


# raw wav to mel spec
class MelSpec(nn.Module):
    def __init__(
        self,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        normalize=False,
        power=1,
        norm=None,
        center=True,
    ):
        super().__init__()
        self.n_mel_channels = n_mel_channels

        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=target_sample_rate,
            n_fft=filter_length,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mel_channels,
            power=power,
            center=center,
            normalized=normalize,
            norm=norm,
        )

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, inp):
        if len(inp.shape) == 3:
            inp = inp.squeeze(1)  # 'b 1 nw -> b nw'

        assert len(inp.shape) == 2

        if self.dummy.device != inp.device:
            self.to(inp.device)

        mel = self.mel_stft(inp)
        mel = mel.clamp(min=1e-5).log()
        return mel
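

# Usage sketch (illustrative addition, not part of the original module): raw waveforms in,
# log-mel features out. Shapes follow the defaults above (100 mels, hop 256, 24 kHz).
def _melspec_example():
    mel_spec = MelSpec()
    wav = torch.randn(2, 24_000)  # 'b nw': two 1-second clips at 24 kHz
    mel = mel_spec(wav)  # 'b d n': (2, 100, 24_000 // 256 + 1) frames with center=True
    return mel.shape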


# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out


class CausalConvPositionEmbedding(nn.Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.kernel_size = kernel_size

        self.conv1 = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
            nn.Mish(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))  # causal: left-pad the time axis by kernel_size - 1
        x = self.conv1(x)
        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
        x = self.conv2(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out
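

# Usage sketch (illustrative addition, not part of the original module): both variants are
# shape-preserving over 'b n d'; the causal variant only left-pads, so frame i never sees
# frames > i. dim=512 here is an arbitrary example width.
def _conv_pos_embedding_example():
    x = torch.randn(2, 128, 512)  # 'b n d'
    mask = torch.ones(2, 128, dtype=torch.bool)  # 'b n', True = keep
    mask[1, 100:] = False  # second sample is padded beyond frame 100
    y = ConvPositionEmbedding(dim=512)(x, mask=mask)  # (2, 128, 512)
    y_causal = CausalConvPositionEmbedding(dim=512)(x, mask=mask)  # (2, 128, 512)
    return y.shape, y_causal.shape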


# rotary positional embedding related
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)


def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    # length = length if isinstance(length, int) else length.max()
    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
    pos = (
        start.unsqueeze(1)
        + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
    )
    # avoid extra long error.
    pos = torch.where(pos < max_pos, pos, max_pos - 1)
    return pos
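

# Usage sketch (illustrative addition, not part of the original module): precompute a
# (max_pos, dim) table of cos/sin features once, then gather per-sample positions with
# get_pos_embed_indices, e.g. to add an absolute positional embedding to same-width token features.
def _pos_embed_indices_example(dim=512, max_pos=4096):
    freqs_cis = precompute_freqs_cis(dim, max_pos)  # (max_pos, dim)
    start = torch.zeros(2, dtype=torch.long)  # each sample starts at position 0
    pos = get_pos_embed_indices(start, length=128, max_pos=max_pos)  # (2, 128), clamped to max_pos - 1
    pos_embed = freqs_cis[pos]  # (2, 128, dim)
    return pos_embed.shape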


# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        return residual + x
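

# Usage sketch (illustrative addition, not part of the original module): the block is
# shape-preserving over 'b n d', so it can be stacked directly; intermediate_dim is the
# expanded width of the pointwise MLP.
def _convnext_v2_example():
    block = ConvNeXtV2Block(dim=512, intermediate_dim=1024)
    x = torch.randn(2, 128, 512)  # 'b n d'
    return block(x).shape  # (2, 128, 512)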


# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)

        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        scale, shift = torch.chunk(emb, 2, dim=1)

        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x
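

# Usage sketch (illustrative addition, not part of the original module): a 'b d' conditioning
# embedding is projected to six modulation tensors; the normed/shifted x feeds attention,
# while the returned gates/shift/scale modulate the later MLP path (see DiTBlock below).
def _adaln_zero_example():
    norm = AdaLayerNormZero(dim=512)
    x = torch.randn(2, 128, 512)  # 'b n d'
    t = torch.randn(2, 512)  # 'b d' conditioning, e.g. a time embedding
    x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, emb=t)
    return x_mod.shape, gate_msa.shape  # (2, 128, 512), (2, 512)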


# FeedForward
class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        activation = nn.GELU(approximate=approximate)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
        self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.ff(x)


# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        if self.context_dim is not None:
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = nn.Linear(context_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b n d"] = None,  # context c  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        if c is not None:
            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
        else:
            return self.processor(self, x, mask=mask, rope=rope)


# Attention processor
class AttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        batch_size = x.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = mask
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(-1)
            else:
                mask = mask[:, 0, -1].unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)

        return x
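

# Usage sketch (illustrative addition, not part of the original module): plain self-attention
# over 'b n d' with a padding mask. Building the rotary input with x_transformers'
# RotaryEmbedding is an assumption consistent with the apply_rotary_pos_emb import above;
# rope can also simply be left as None.
def _self_attention_example():
    from x_transformers.x_transformers import RotaryEmbedding  # assumed available alongside apply_rotary_pos_emb

    attn = Attention(processor=AttnProcessor(), dim=512, heads=8, dim_head=64)
    x = torch.randn(2, 128, 512)  # 'b n d'
    mask = torch.ones(2, 128, dtype=torch.bool)  # 'b n', True = keep
    rope = RotaryEmbedding(64).forward_from_seq_len(128)  # (freqs, xpos_scale)
    out = attn(x, mask=mask, rope=rope)
    return out.shape  # (2, 128, 512)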


# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.FloatTensor:
        residual = x

        batch_size = c.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # `context` projections.
        c_query = attn.to_q_c(c)
        c_key = attn.to_k_c(c)
        c_value = attn.to_v_c(c)

        # apply rope for context and noised input independently
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
        if c_rope is not None:
            freqs, xpos_scale = c_rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
            c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

        # attention
        query = torch.cat([query, c_query], dim=1)
        key = torch.cat([key, c_key], dim=1)
        value = torch.cat([value, c_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # Split the attention outputs.
        x, c = (
            x[:, : residual.shape[1]],
            x[:, residual.shape[1]:],
        )

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)
        if not attn.context_pre_only:
            c = attn.to_out_c(c)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)
            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)

        return x, c


# DiT Block
class DiTBlock(nn.Module):
    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # attention
        attn_output = self.attn(x=norm, mask=mask, rope=rope)

        # process attention output for input x
        x = x + gate_msa.unsqueeze(1) * attn_output

        ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(ff_norm)
        x = x + gate_mlp.unsqueeze(1) * ff_output

        return x
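

# Usage sketch (illustrative addition, not part of the original module): one DiT block step,
# conditioned on a 'b d' time embedding of the same width as the model dimension.
def _dit_block_example():
    block = DiTBlock(dim=512, heads=8, dim_head=64)
    x = torch.randn(2, 128, 512)  # 'b n d' noised input
    t = torch.randn(2, 512)  # 'b d' time conditioning (e.g. from TimestepEmbedding below)
    return block(x, t).shape  # (2, 128, 512)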


# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
    r"""
    modified from diffusers/src/diffusers/models/attention.py
    notes.
    _c: context related. text, cond, etc. (left part in sd3 fig2.b)
    _x: noised input related. (right part)
    context_pre_only: last layer only do prenorm + modulation cuz no more ffn
    """

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only

        self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
        self.attn_norm_x = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=JointAttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            context_dim=dim,
            context_pre_only=context_pre_only,
        )

        if not context_pre_only:
            self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
        else:
            self.ff_norm_c = None
            self.ff_c = None
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: noised input, c: context, t: time embedding
        # pre-norm & modulation for attention input
        if self.context_pre_only:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)

        # attention
        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)

        # process attention output for context c
        if self.context_pre_only:
            c = None
        else:  # if not last layer
            c = c + c_gate_msa.unsqueeze(1) * c_attn_output

            norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            c_ff_output = self.ff_c(norm_c)
            c = c + c_gate_mlp.unsqueeze(1) * c_ff_output

        # process attention output for input x
        x = x + x_gate_msa.unsqueeze(1) * x_attn_output

        norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
        x_ff_output = self.ff_x(norm_x)
        x = x + x_gate_mlp.unsqueeze(1) * x_ff_output

        return c, x
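

# Usage sketch (illustrative addition, not part of the original module): joint attention over
# a context stream c (e.g. text, length nt) and the noised stream x (length n); with
# context_pre_only=True (final layer), the returned c is None.
def _mmdit_block_example():
    block = MMDiTBlock(dim=512, heads=8, dim_head=64)
    x = torch.randn(2, 128, 512)  # 'b n d'
    c = torch.randn(2, 32, 512)  # 'b nt d'
    t = torch.randn(2, 512)  # 'b d'
    c_out, x_out = block(x, c, t)
    return c_out.shape, x_out.shape  # (2, 32, 512), (2, 128, 512)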


# time step conditioning embedding
class TimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):  # noqa: F821
        time_hidden = self.time_embed(timestep)
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time
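

# Usage sketch (illustrative addition, not part of the original module): a batch of scalar
# flow/diffusion times becomes the 'b d' conditioning vector consumed by the AdaLN layers above.
def _timestep_embedding_example():
    time_embed = TimestepEmbedding(dim=512)
    t = torch.rand(2)  # 'b', e.g. times in [0, 1]
    return time_embed(t).shape  # (2, 512)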