han's Issues

RuntimeError: The size of tensor a (768) must match the size of tensor b (30522) at non-singleton dimension 1

Hi, I trained the model on Google Colab with a T4 GPU.
I ran into an out-of-memory error during training, so I reduced the batch size to 64. After that change, I get the tensor size mismatch error shown in the traceback below. I used train.csv, dev.csv and test.csv from the dataset: https://zenodo.org/record/7095100.

I used the following code in depression_classifier.py to train the model with bert-base-uncased:

from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

As I am new to this architecture, please let me know whether any changes to depression_classifier.py are needed to accommodate the new batch size and resolve this issue.
Thank you.
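A possible reading of the two numbers in the error, offered as a sketch rather than a confirmed diagnosis: 768 is the hidden size of bert-base-uncased and 30522 is its vocabulary size. A model loaded with AutoModelForMaskedLM returns per-token vocabulary logits by default, while a plain encoder returns hidden states. How depression_classifier.py actually consumes the model is not shown here, so the example below only compares the two output shapes. It assumes a tiny, randomly initialised BertConfig (hidden size 16, vocabulary 100) as a stand-in so that nothing is downloaded.

```python
import torch
from transformers import BertConfig, BertModel, BertForMaskedLM

# Tiny random config (an assumption for illustration only): hidden_size=16
# stands in for 768 and vocab_size=100 stands in for 30522.
config = BertConfig(
    vocab_size=100,
    hidden_size=16,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=32,
)

input_ids = torch.tensor([[1, 5, 7, 2]])  # one sequence of four token ids

encoder = BertModel(config).eval()            # plain encoder, no LM head
masked_lm = BertForMaskedLM(config).eval()    # encoder plus masked-LM head

with torch.no_grad():
    # Encoder output: one hidden vector per token.
    hidden = encoder(input_ids=input_ids).last_hidden_state
    # Masked-LM output: one score per vocabulary entry per token.
    mlm_out = masked_lm(input_ids=input_ids, output_hidden_states=True)
    logits = mlm_out.logits
    # The hidden states are still reachable when explicitly requested.
    hidden_from_mlm = mlm_out.hidden_states[-1]

hidden_shape = tuple(hidden.shape)
logits_shape = tuple(logits.shape)
hidden_from_mlm_shape = tuple(hidden_from_mlm.shape)
print(hidden_shape, logits_shape, hidden_from_mlm_shape)
```

If the classifier expects hidden-size features, loading the encoder without the language-model head, or requesting hidden states explicitly, would keep the last dimension at the hidden size. Whether that applies to this repository is an assumption to verify against its code.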

RuntimeError Traceback (most recent call last)
in <cell line: 6>()
108
109 print("model training in batches [size: %s]" % train_batch_size)
--> 110 model_training(train_set_path, heldout_set_path, evaluation_data_path, no_gpu, train_batch_size,
111 model_file_prefix,
112 num_epochs=num_epochs,

10 frames
/content/depression_classifier.py in model_training(train_set_path, validation_set_path, test_set_path, n_gpu, train_batch_size, model_file_prefix, num_epochs, max_post_size_option)
601 cuda_device=n_gpu
602 )
--> 603 trainer.train()
604 timestamped_print("done.")
605

/usr/local/lib/python3.10/dist-packages/allennlp/training/gradient_descent_trainer.py in train(self)
769
770 try:
--> 771 metrics, epoch = self._try_train()
772 return metrics
773 finally:

/usr/local/lib/python3.10/dist-packages/allennlp/training/gradient_descent_trainer.py in _try_train(self)
791 for epoch in range(self._num_epochs):
792 epoch_start_time = time.time()
--> 793 train_metrics = self._train_epoch(epoch)
794
795 if self._epochs_completed < self._start_after_epochs_completed:

/usr/local/lib/python3.10/dist-packages/allennlp/training/gradient_descent_trainer.py in _train_epoch(self, epoch)
508
509 with amp.autocast(self._use_amp):
--> 510 batch_outputs = self.batch_outputs(batch, for_training=True)
511 batch_group_outputs.append(batch_outputs)
512 loss = batch_outputs["loss"]

/usr/local/lib/python3.10/dist-packages/allennlp/training/gradient_descent_trainer.py in batch_outputs(self, batch, for_training)
401 returns, after adding any specified regularization penalty to the loss (if training).
402 """
--> 403 output_dict = self._pytorch_model(**batch)
404
405 if for_training:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1146 input = bw_hook.setup_input_hook(input)
1147
-> 1148 result = forward_call(*input, **kwargs)
1149 if _global_forward_hooks or self._forward_hooks:
1150 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

/content/depression_classifier.py in forward(self, user_id, label)
235
236 content_tensor_in_batch_padded, batch_content_mask, metaphor_tensor_in_batch_padded, batch_metaphor_mask
--> 237 = self.batch_encoding(user_id, maximum_sequence_length=self.max_post_size)
238
239 timestamped_print("encode social context with LSTM")

/content/depression_classifier.py in batch_encoding(self, user_ids, maximum_sequence_length)
352 batch_metaphor_mask = torch.Tensor()
353
--> 354 metaphor_tensor_in_batch_padded, batch_metaphor_mask = self.padding_and_norm_propagation_tensors(
355 metaphor_tensor_in_batch,
356 maximum_sequence_length=maximum_sequence_length)

/content/depression_classifier.py in padding_and_norm_propagation_tensors(self, list_of_context_seq_tensor, maximum_sequence_length)
416 timestamped_print("done")
417
--> 418 raise err
419
420 timestamped_print("tensor size after padding: %s" % str(

/content/depression_classifier.py in padding_and_norm_propagation_tensors(self, list_of_context_seq_tensor, maximum_sequence_length)
404 list_of_context_seq_tensor.insert(0, torch.zeros(maximum_sequence_length, feature_dim))
405 # zero padding
--> 406 batch_propagation_tensors = torch.nn.utils.rnn.pad_sequence(list_of_context_seq_tensor, batch_first=True)
407 # remove the dummy tensor
408 batch_propagation_tensors = batch_propagation_tensors[1:]

/usr/local/lib/python3.10/dist-packages/torch/nn/utils/rnn.py in pad_sequence(sequences, batch_first, padding_value)
394 # assuming trailing dimensions and type of all the Tensors
395 # in sequences are same and fetching those from sequences[0]
--> 396 return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
397
398

RuntimeError: The size of tensor a (768) must match the size of tensor b (30522) at non-singleton dimension 1
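The traceback shows a dummy tensor of shape (maximum_sequence_length, feature_dim) being inserted before the call to pad_sequence. pad_sequence takes the trailing dimensions from the first sequence, so any later tensor with a different feature size fails during the copy. The sketch below reproduces that failure with small stand-in sizes (8 and 20 instead of 768 and 30522) and adds an explicit shape check. The helper name pad_with_dummy is hypothetical and is not taken from depression_classifier.py.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_with_dummy(seqs, max_len, feature_dim):
    """Pad sequences to at least max_len using a dummy first tensor.

    Hypothetical helper mirroring the pattern visible in the traceback,
    with a feature-size check that fails early with a clearer message.
    """
    for i, t in enumerate(seqs):
        if t.shape[-1] != feature_dim:
            raise ValueError(
                f"sequence {i} has feature size {t.shape[-1]}, "
                f"expected {feature_dim}"
            )
    dummy = torch.zeros(max_len, feature_dim)
    padded = pad_sequence([dummy] + list(seqs), batch_first=True)
    return padded[1:]  # drop the dummy tensor again


# Matching feature sizes: padding works.
ok = pad_with_dummy([torch.randn(3, 8), torch.randn(5, 8)], max_len=6, feature_dim=8)
print(tuple(ok.shape))

# Mismatched feature sizes without the check: pad_sequence itself raises.
raw_error = None
try:
    pad_sequence([torch.zeros(6, 8), torch.randn(3, 20)], batch_first=True)
except RuntimeError as err:
    raw_error = str(err)
print(raw_error)
```

Under this reading, the batch size would not be the cause: the mismatch comes from the feature dimension of the tensors being padded, which is independent of how many sequences are in the batch.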
