|
@@ -158,12 +158,9 @@ class CausalAttnProcessor2_0(AttnProcessor2_0):
|
|
|
|
|
|
|
|
key_cache = attn.to_k(encoder_hidden_states)
|
|
key_cache = attn.to_k(encoder_hidden_states)
|
|
|
value_cache = attn.to_v(encoder_hidden_states)
|
|
value_cache = attn.to_v(encoder_hidden_states)
|
|
|
- # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
|
|
|
|
|
- if cache.size(0) != 0:
|
|
|
|
|
- key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
|
|
|
|
|
- value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
|
|
|
|
|
- else:
|
|
|
|
|
- key, value = key_cache, value_cache
|
|
|
|
|
|
|
+ # NOTE always concat cache for interface compatibility
|
|
|
|
|
+ key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
|
|
|
|
|
+ value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
|
|
|
cache = torch.stack([key_cache, value_cache], dim=3)
|
|
cache = torch.stack([key_cache, value_cache], dim=3)
|
|
|
|
|
|
|
|
inner_dim = key.shape[-1]
|
|
inner_dim = key.shape[-1]
|
|
@@ -799,6 +796,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
|
|
|
output = self.final_proj(x * mask_up)
|
|
output = self.final_proj(x * mask_up)
|
|
|
return output * mask
|
|
return output * mask
|
|
|
|
|
|
|
|
|
|
+ @torch.inference_mode()
|
|
|
def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
|
|
def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
|
|
|
down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
|
down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
|
|
down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
|
|
down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
|