lyuxiang.lx 1 년 전
부모
커밋
3770c1c8b1
1개의 변경된 파일6개의 추가작업 그리고 3개의 파일을 삭제
  1. 6 3
      cosyvoice/flow/decoder.py

+ 6 - 3
cosyvoice/flow/decoder.py

@@ -158,9 +158,12 @@ class CausalAttnProcessor2_0(AttnProcessor2_0):
 
         key_cache = attn.to_k(encoder_hidden_states)
         value_cache = attn.to_v(encoder_hidden_states)
-        # NOTE always concat cache for interface compatibility
-        key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
-        value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
+        # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
+        if cache.size(0) != 0:
+            key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
+            value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
+        else:
+            key, value = key_cache, value_cache
         cache = torch.stack([key_cache, value_cache], dim=3)
 
         inner_dim = key.shape[-1]