I'm encountering a RuntimeError — "mat1 and mat2 shapes cannot be multiplied (2x512 and 768x1)" — when testing the fit and predict methods of a belt_nlp BertClassifierWithPooling model built on a pretrained model. From the shapes in the message, it looks like the pretrained model produces 512-dimensional embeddings while the classifier's linear head expects 768-dimensional input. Has anyone encountered this issue before, and if so, do you have any suggestions on how to resolve it?
Full error log:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[22], line 1
----> 1 model.fit(X_train, y_train, epochs=1)
File /opt/conda/lib/python3.10/site-packages/belt_nlp/bert.py:80, in BertClassifier.fit(self, x_train, y_train, epochs)
76 dataloader = DataLoader(
77 dataset, sampler=RandomSampler(dataset), batch_size=self.batch_size, collate_fn=self.collate_fn
78 )
79 for epoch in range(epochs):
---> 80 self._train_single_epoch(dataloader, optimizer)
File /opt/conda/lib/python3.10/site-packages/belt_nlp/bert.py:126, in BertClassifier._train_single_epoch(self, dataloader, optimizer)
123 for step, batch in enumerate(dataloader):
125 labels = batch[-1].float().cpu()
--> 126 predictions = self._evaluate_single_batch(batch)
127 loss = cross_entropy(predictions, labels) / self.accumulation_steps
128 loss.backward()
File /opt/conda/lib/python3.10/site-packages/belt_nlp/bert_with_pooling.py:124, in BertClassifierWithPooling._evaluate_single_batch(self, batch)
119 attention_mask_combined_tensors = torch.stack(
120 [torch.tensor(x).to(self.device) for x in attention_mask_combined]
121 )
123 # get model predictions for the combined batch
--> 124 preds = self.neural_network(input_ids_combined_tensors, attention_mask_combined_tensors)
126 preds = preds.flatten().cpu()
128 # split result preds into chunks
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/belt_nlp/bert.py:180, in BertClassifierNN.forward(self, input_ids, attention_mask)
177 x = x[0][:, 0, :] # take <s> token (equiv. to [CLS])
179 # classification head
--> 180 x = self.linear(x)
181 x = self.sigmoid(x)
182 return x
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x512 and 768x1)