lyuxiang.lx 7 months ago
parent
commit
c07cd3d730
1 changed files with 3 additions and 3 deletions
  1. 3 3
      cosyvoice/flow/flow_matching.py

+ 3 - 3
cosyvoice/flow/flow_matching.py

@@ -270,11 +270,11 @@ class CausalConditionalCFM(ConditionalCFM):
             # NOTE if smaller than flow_cache_size, means last chunk, no need to cache
             if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
                 cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
-                cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:,:,:,-flow_cache_size:]
+                cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
                 cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
-                cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:,:,:,-flow_cache_size:]
+                cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
                 cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
-                cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:,:,:,-flow_cache_size:]
+                cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
                 cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)