@@ -270,11 +270,11 @@ class CausalConditionalCFM(ConditionalCFM):
 # NOTE if smaller than flow_cache_size, means last chunk, no need to cache
 if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
     cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
-    cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:,:,:,-flow_cache_size:]
+    cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
     cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
-    cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:,:,:,-flow_cache_size:]
+    cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
     cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
-    cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:,:,:,-flow_cache_size:]
+    cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
     cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
 dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
 dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
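For context on what the trimmed tensors hold, a minimal sketch of the truncation, assuming the kv cache is laid out as [batch, heads, head_dim, time] (the shapes here are hypothetical; only the last-axis slice mirrors the diff). Slicing with `-flow_cache_size:` keeps just the trailing `flow_cache_size` frames, so the per-step cache stays bounded across streaming chunks:

```python
import torch

# Hypothetical shapes; only the last-axis slice mirrors the diff.
flow_cache_size = 4
kv = torch.randn(2, 8, 64, 10)  # a cache that has grown past the window

# Keep only the trailing flow_cache_size frames along the time axis,
# as in cache_step[1][:, :, :, -flow_cache_size:] above.
kv_trimmed = kv[:, :, :, -flow_cache_size:]
assert kv_trimmed.shape == (2, 8, 64, flow_cache_size)
```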
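And a sketch of the classifier-free-guidance blend in the last two lines: the estimator is run on a doubled (conditional + unconditional) batch, its output is split back into halves, and the halves are combined via `inference_cfg_rate`. The estimator output here is a stand-in random tensor; variable names mirror the diff.

```python
import torch

inference_cfg_rate = 0.7
x = torch.randn(2, 80, 50)                    # conditional half of the batch
dphi_dt = torch.randn(2 * x.size(0), 80, 50)  # stand-in for the estimator output on the doubled batch

# Split the doubled batch back into conditional / unconditional halves,
# then blend: (1 + rate) * conditional - rate * unconditional.
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
```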