# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock


class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


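# CausalConv1d left-pads its input by kernel_size - 1 frames instead of using symmetric
# padding, so each output frame depends only on the current and past frames and (with the
# default dilation of 1) the sequence length is preserved.
# Illustrative shape sketch (hypothetical sizes, for illustration only):
#   conv = CausalConv1d(4, 4, kernel_size=3)
#   y = conv(torch.randn(1, 4, 10))  # y.shape == (1, 4, 10); y[..., t] sees x[..., :t + 1]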
class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        assert stride == 1
        self.causal_padding = kernel_size - 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, (self.causal_padding, 0), value=0.0)
        x = super(CausalConv1d, self).forward(x)
        return x


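# CausalBlock1D is a causal drop-in for matcha's Block1D: the 3-wide convolution only looks
# at past frames, and normalization is a LayerNorm over the channel dimension (hence the
# Transpose pair around it), followed by Mish.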
class CausalBlock1D(Block1D):
    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        self.block = torch.nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        output = self.block(x * mask)
        return output * mask


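# CausalResnetBlock1D keeps ResnetBlock1D's residual structure but swaps both inner blocks
# for the causal variant above, so the whole residual block never looks at future frames.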
class CausalResnetBlock1D(ResnetBlock1D):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)


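# ConditionalDecoder is a 1-D U-Net over the time axis: a down path, a middle stack and an
# up path, each stage pairing a ResNet block with transformer blocks, conditioned on a
# sinusoidal timestep embedding (and optionally a speaker embedding and an extra cond tensor).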
class ConditionalDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same shape as the target, so if your text content
        is shorter or longer than the output, resample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
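
        # Down path: each stage is a ResNet block, n_blocks transformer blocks, and a
        # downsample step (the last stage uses a stride-1 Conv1d, so the time resolution
        # is kept there).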
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
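
        # Middle stack: num_mid_blocks (ResNet block + transformer blocks) pairs operating
        # at the deepest channel width, with no further time-axis resampling.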
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
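
        # Up path: the channel list is mirrored back, and each stage consumes 2x channels
        # because forward() concatenates the corresponding down-path hidden state (skip
        # connection) before the ResNet block.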
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
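
    # forward() packs x, mu and (optionally) the broadcast speaker embedding and cond along
    # the channel axis, so in_channels must equal the total channel count of that packed
    # tensor; the output has out_channels channels and the same length as the mask.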
    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): conditioning tensor, packed with x along the channel axis
            t (torch.Tensor): shape (batch_size,)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): placeholder for future use. Defaults to None.
            streaming (bool, optional): unused in this non-causal decoder. Defaults to False.

        Returns:
            torch.Tensor: shape (batch_size, out_channels, time)
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask


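# CausalConditionalDecoder is the streaming-friendly variant of ConditionalDecoder: the
# ResNet blocks, the 3-wide stride-1 convolutions and the final block are swapped for the
# causal versions above, and when streaming=True the transformer layers use chunk-based
# attention masks instead of full ones.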
class CausalConditionalDecoder(ConditionalDecoder):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        static_chunk_size=50,
        num_decoding_left_chunks=2,
    ):
        """
        This decoder requires an input with the same shape as the target, so if your text content
        is shorter or longer than the output, resample it before feeding it to the decoder.
        """
        torch.nn.Module.__init__(self)
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
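
        # static_chunk_size sets the chunk width of the attention masks that forward()
        # builds when streaming=True; the block layout below mirrors ConditionalDecoder,
        # with causal modules in place of the non-causal ones.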
        self.down_blocks = nn.ModuleList([])
        self.mid_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
            )
            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): conditioning tensor, packed with x along the channel axis
            t (torch.Tensor): shape (batch_size,)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): placeholder for future use. Defaults to None.
            streaming (bool, optional): if True, use chunk-based attention masks of width
                static_chunk_size. Defaults to False.

        Returns:
            torch.Tensor: shape (batch_size, out_channels, time)
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]
        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
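

# Streaming call sketch (hypothetical channel sizes, for illustration only): with
# streaming=True every transformer block uses a chunk mask of static_chunk_size frames,
# so each frame attends only to frames in its own chunk and in earlier chunks.
#   decoder = CausalConditionalDecoder(in_channels=320, out_channels=80, static_chunk_size=50)
#   out = decoder(x, mask, mu, t, spks=spks, cond=cond, streaming=True)  # (B, 80, T)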