|
|
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
|
|
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
|
|
|
- y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1,
|
|
|
|
|
|
|
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
|
|
|
att_cache=att_cache, cnn_cache=cnn_cache,
|
|
att_cache=att_cache, cnn_cache=cnn_cache,
|
|
|
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
|
|
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
|