|
|
@@ -49,7 +49,7 @@ class CausalBlock1D(Block1D):
|
|
|
|
|
|
|
|
|
class CausalResnetBlock1D(ResnetBlock1D):
|
|
|
- def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int=8):
|
|
|
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
|
|
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
|
|
self.block1 = CausalBlock1D(dim, dim_out)
|
|
|
self.block2 = CausalBlock1D(dim_out, dim_out)
|
|
|
@@ -70,12 +70,11 @@ class CausalConv1d(torch.nn.Conv1d):
|
|
|
dtype=None
|
|
|
) -> None:
|
|
|
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
|
|
- kernel_size, stride,
|
|
|
- padding=0, dilation=dilation,
|
|
|
- groups=groups, bias=bias,
|
|
|
- padding_mode=padding_mode,
|
|
|
- device=device, dtype=dtype
|
|
|
- )
|
|
|
+ kernel_size, stride,
|
|
|
+ padding=0, dilation=dilation,
|
|
|
+ groups=groups, bias=bias,
|
|
|
+ padding_mode=padding_mode,
|
|
|
+ device=device, dtype=dtype)
|
|
|
assert stride == 1
|
|
|
self.causal_padding = (kernel_size - 1, 0)
|
|
|
|
|
|
@@ -124,7 +123,8 @@ class ConditionalDecoder(nn.Module):
|
|
|
input_channel = output_channel
|
|
|
output_channel = channels[i]
|
|
|
is_last = i == len(channels) - 1
|
|
|
- resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
|
|
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
|
|
|
+ else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
|
|
transformer_blocks = nn.ModuleList(
|
|
|
[
|
|
|
BasicTransformerBlock(
|
|
|
@@ -138,14 +138,16 @@ class ConditionalDecoder(nn.Module):
|
|
|
]
|
|
|
)
|
|
|
downsample = (
|
|
|
- Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
|
|
+ Downsample1D(output_channel) if not is_last else \
|
|
|
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
|
|
)
|
|
|
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
|
|
|
|
|
for _ in range(num_mid_blocks):
|
|
|
input_channel = channels[-1]
|
|
|
out_channels = channels[-1]
|
|
|
- resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
|
|
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
|
|
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
|
|
|
|
|
transformer_blocks = nn.ModuleList(
|
|
|
[
|