Here's the full output from that step in case it helps; I had to set `device='cpu'` to get a useful error message...
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [28], in <cell line: 9>()
     20 optimizer.zero_grad()
     22 with torch.cuda.amp.autocast():
---> 23     out = gpt.forward(**batch,)
     25 loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),
     26                        reduction='mean', label_smoothing=0.1)
     28 print(loss)

File ~/miniconda3/envs/pytorch/lib/python3.8/site-packages/transformers/models/gptj/modeling_gptj.py:805, in GPTJForCausalLM.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    800 hidden_states = hidden_states.to(self.lm_head.weight.device)
    802 # make sure sampling in fp16 works correctly and
    803 # compute loss in fp32 to match with mesh-tf version
    804 # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
--> 805 lm_logits = self.lm_head(hidden_states).to(torch.float32)
    807 loss = None
    808 if labels is not None:
    809     # Shift so that tokens < n predict n

File ~/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

Input In [5], in FrozenBNBLinear.forward(self, input)
     13 output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
     14 if self.adapter:
---> 15     output += self.adapter(input)
     16 return output

RuntimeError: Output 0 of DequantizeAndLinearBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
```
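The error comes from the `output += self.adapter(input)` line: autograd marks the output of a custom `Function` as a view when an input is returned as-is, and an in-place op on that view is forbidden. Here's a minimal toy reproduction (the `Passthrough` class is my own stand-in for `DequantizeAndLinear`, not code from the notebook) showing the failure and the clone fix the error message suggests; in `FrozenBNBLinear.forward` itself, one fix along the same lines would be an out-of-place add, `output = output + self.adapter(input)`.

```python
import torch

class Passthrough(torch.autograd.Function):
    """Toy stand-in for DequantizeAndLinear: returning an input tensor
    as-is makes autograd treat the output as a view, so any later
    in-place op on it raises exactly the RuntimeError above."""
    @staticmethod
    def forward(ctx, x):
        return x  # input returned as-is -> output is flagged as a view

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

x = torch.randn(3, requires_grad=True)

# Reproduce the failure: in-place add, like `output += self.adapter(input)`
out = Passthrough.apply(x)
try:
    out += 1.0
except RuntimeError as e:
    print("reproduced:", type(e).__name__)

# Fix: clone the custom Function's output before modifying it in place
out = Passthrough.apply(x).clone()
out += 1.0            # now allowed
out.sum().backward()  # gradient still flows through the clone
print(x.grad)
```

This keeps the custom backward intact: `.clone()` breaks the view relationship without detaching from the graph, so gradients still reach `x` (here they come out as all ones, since the add of a constant contributes nothing).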