@@ -210,6 +210,7 @@ class CausalAttention(Attention):
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
@@ -223,7 +224,7 @@ class CausalAttention(Attention):
processor: Optional["AttnProcessor2_0"] = None,
out_dim: int = None,
):
- super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups,
+ super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, qk_norm,
added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim)
processor = CausalAttnProcessor2_0()
self.set_processor(processor)
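For context: the new `qk_norm` argument is simply threaded through to the diffusers-style `Attention` base class, which normalizes the query and key projections before the dot product. The sketch below is an illustrative stand-in (the hypothetical `QKNormSelfAttention`, not the patched class), assuming `qk_norm="layer_norm"`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKNormSelfAttention(nn.Module):
    """Illustrative self-attention with per-head LayerNorm on q and k.

    Normalizing q/k bounds the attention logits, which helps stability
    during training and in low-precision (fp16/bf16) inference.
    """

    def __init__(self, dim: int, heads: int = 8, qk_norm: str = "layer_norm"):
        super().__init__()
        self.heads, self.dim_head = heads, dim // heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        use_ln = qk_norm == "layer_norm"
        self.norm_q = nn.LayerNorm(self.dim_head) if use_ln else nn.Identity()
        self.norm_k = nn.LayerNorm(self.dim_head) if use_ln else nn.Identity()

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor:
        b, t, _ = x.shape
        split = lambda y: y.view(b, t, self.heads, self.dim_head).transpose(1, 2)
        q, k, v = split(self.to_q(x)), split(self.to_k(x)), split(self.to_v(x))
        q, k = self.norm_q(q), self.norm_k(k)  # the qk_norm step
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return out.transpose(1, 2).reshape(b, t, -1)
```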
@@ -505,7 +506,7 @@ class ConditionalDecoder(nn.Module):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
- def forward(self, x, mask, mu, t, spks=None, cond=None):
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
"""Forward pass of the UNet1DConditional model.
Args:
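Note that the non-causal `ConditionalDecoder` gains the `streaming` argument but, judging by the hunks that follow, never branches on it: its attention masks are always full-context. Presumably the flag is accepted only so both decoders expose the same call signature.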
@@ -540,7 +541,7 @@ class ConditionalDecoder(nn.Module):
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
- attn_mask = (torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) == 1)
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
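The two mask constructions agree wherever it matters. Assuming the WeNet semantics of `add_optional_chunk_mask`, passing `use_dynamic_chunk=False` with a static chunk size of 0 returns the padding mask unchanged, and `.repeat(1, x.size(1), 1)` expands it to a full `(B, T, T)` attention mask. The only difference from the old outer-product mask is on padded query rows, which now still attend to the valid keys instead of being fully masked; that avoids an all-masked softmax row (a classic NaN source), and those rows are zeroed downstream anyway. A runnable check of that claim:

```python
import torch

B, T = 2, 6
mask = torch.ones(B, 1, T)
mask[0, 0, 4:] = 0  # first item has two padded frames

# Old: outer product -> (i, j) allowed iff frames i AND j are both valid.
old = torch.matmul(mask.transpose(1, 2), mask) == 1   # (B, T, T)

# New: the padding mask repeated over query positions -> j valid is enough.
new = (mask == 1).repeat(1, T, 1)                     # (B, T, T)

# Identical on every valid query row; padded rows differ by design.
valid_rows = mask.squeeze(1).bool()                   # (B, T)
assert torch.equal(old[valid_rows], new[valid_rows])
```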
@@ -558,7 +559,7 @@ class ConditionalDecoder(nn.Module):
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
- attn_mask = (torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) == 1)
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
@@ -574,7 +575,7 @@ class ConditionalDecoder(nn.Module):
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
- attn_mask = (torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) == 1)
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
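After each mask is built, `mask_to_bias` converts the boolean mask into an additive float bias that can be handed directly to PyTorch's scaled-dot-product attention. A sketch of the usual conversion (the repo's own helper may use a different negative constant):

```python
import torch

def mask_to_bias_sketch(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """True -> 0.0 (keep); False -> large negative (suppressed by softmax)."""
    # A large finite value is safer than -inf: a fully masked row degrades to
    # a uniform softmax instead of producing NaNs.
    return (1.0 - mask.to(dtype)) * torch.finfo(dtype).min

m = torch.tensor([[[True, False], [True, True]]])
print(mask_to_bias_sketch(m, torch.float32))  # ~0 where allowed, ~-3.4e38 where masked
```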
@@ -700,7 +701,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
- def forward(self, x, mask, mu, t, spks=None, cond=None):
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
"""Forward pass of the UNet1DConditional model.
Args:
@@ -735,7 +736,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
mask_down = masks[-1]
x, _, _ = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
- attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ if streaming is True:
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ else:
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x, _ = transformer_block(
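Here the new `streaming` flag picks between two masks: the chunk-causal one, driven by `self.static_chunk_size` and `self.num_decoding_left_chunks`, for streaming synthesis, and the same full-context padding mask as the non-causal decoder for offline synthesis, so only the mask (not the weights) changes between modes. The pattern the streaming branch requests looks like the following sketch (function name and loop are mine, modeled on WeNet's `subsequent_chunk_mask`):

```python
import torch

def chunk_causal_mask_sketch(size: int, chunk_size: int,
                             num_left_chunks: int = -1) -> torch.Tensor:
    """(size, size) bool mask: frame i sees its own chunk plus
    num_left_chunks previous chunks (-1 = unlimited history)."""
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        end = min((i // chunk_size + 1) * chunk_size, size)
        start = 0 if num_left_chunks < 0 else \
            max(0, (i // chunk_size - num_left_chunks) * chunk_size)
        mask[i, start:end] = True
    return mask

# chunk_size=2: frames 0-1 see frames 0-1, frames 2-3 see 0-3, frames 4-5 see 0-5.
print(chunk_causal_mask_sketch(6, 2).int())
```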
@@ -753,7 +757,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
for resnet, transformer_blocks in self.mid_blocks:
x, _, _ = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
- attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ if streaming is True:
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ else:
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x, _ = transformer_block(
@@ -769,7 +776,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x, _, _ = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
- attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ if streaming is True:
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+ else:
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x, _ = transformer_block(