@@ -123,8 +123,8 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal \
-                else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
+                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -138,7 +138,7 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else \
+                Downsample1D(output_channel) if not is_last else
                 CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
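The `downsample` expression above is a chained conditional: Python's conditional expression is right-associative, so it parses as `Downsample1D(...) if not is_last else (CausalConv1d(...) if self.causal else nn.Conv1d(...))`. The backslash dropped on the `+` side was never needed, because the enclosing parentheses already allow implicit line joining. A standalone sketch with toy values (strings stand in for the real modules):

```python
# Right-associative chained conditional: parses as
#   A if not is_last else (B if causal else C)
is_last, causal = True, False

downsample = (
    "Downsample1D" if not is_last else      # no backslash needed: the ( )
    "CausalConv1d" if causal else "Conv1d"  # gives implicit line joining
)
print(downsample)  # -> Conv1d
```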
@@ -147,7 +147,7 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[-1]
             out_channels = channels[-1]
             resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
-            ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+                ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
@@ -251,7 +251,7 @@ class ConditionalDecoder(nn.Module):
             x = rearrange(x, "b c t -> b t c").contiguous()
             # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
             attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
-            attn_mask = mask_to_bias(attn_mask==1, x.dtype)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
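For context on the line being reformatted: `mask_to_bias` turns a boolean attention mask into an additive bias (zero where attention is allowed, a large negative number where it is not), which is why the chunk mask is first compared with `== 1` to produce a bool tensor. A minimal sketch of that behavior, assuming a WeNet-style helper; `mask_to_bias_sketch` is illustrative, not the project's exact source:

```python
import torch

def mask_to_bias_sketch(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Boolean mask (True = may attend) -> additive attention bias."""
    assert mask.dtype == torch.bool  # the `== 1` comparison upstream guarantees this
    return (1.0 - mask.to(dtype)) * -1.0e10

m = torch.tensor([[1, 1, 0]])
print(mask_to_bias_sketch(m == 1, torch.float32))
# ~ [[0., 0., -1e10]]: masked positions get a large negative bias
```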
@@ -270,7 +270,7 @@ class ConditionalDecoder(nn.Module):
             x = rearrange(x, "b c t -> b t c").contiguous()
             # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
             attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
-            attn_mask = mask_to_bias(attn_mask==1, x.dtype)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -287,7 +287,7 @@ class ConditionalDecoder(nn.Module):
             x = rearrange(x, "b c t -> b t c").contiguous()
             # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
             attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
-            attn_mask = mask_to_bias(attn_mask==1, x.dtype)
+            attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
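The same pattern repeats across the down, mid, and up paths: `add_optional_chunk_mask` (called here with dynamic chunking disabled and `self.static_chunk_size` set) builds a chunk-wise causal mask in which each position may attend to all previous chunks plus its own. A toy illustration of that mask shape; `chunk_mask_sketch` is a hypothetical stand-in, not the library function:

```python
import torch

def chunk_mask_sketch(size: int, chunk_size: int) -> torch.Tensor:
    """Position i may attend up to the end of its own chunk (full left context)."""
    chunk_ends = (torch.arange(size) // chunk_size + 1) * chunk_size
    return torch.arange(size).unsqueeze(0) < chunk_ends.unsqueeze(1)

print(chunk_mask_sketch(4, 2).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])
```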
@@ -298,4 +298,4 @@ class ConditionalDecoder(nn.Module):
             x = upsample(x * mask_up)
         x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
-        return output * mask
+        return output * mask