
mim-solutions / bert_for_longer_texts


BERT classification model for processing texts longer than 512 tokens. Text is first divided into smaller chunks and after feeding them to BERT, intermediate results are pooled. The implementation allows fine-tuning.

License: Other

Python 51.76% Shell 0.08% Jupyter Notebook 46.66% Makefile 0.66% Batchfile 0.83%
nlp natural-language-processing bert text-classification transfer-learning transformers deep-learning machine-learning pytorch roberta

bert_for_longer_texts's Introduction

BELT (BERT For Longer Texts)

🚀 New in version 1.1.0: support for multilabel and regression. See the examples 🚀

Project description and motivation

The BELT approach

The BERT model can process texts of at most 512 tokens (roughly speaking, tokens correspond to words or pieces of words). This limit is a consequence of the model architecture and cannot be adjusted directly. A discussion of this issue can be found here. A method to overcome it was proposed by Devlin (one of the authors of BERT) in that discussion: comment. The main goal of our project is to implement this method and allow the BERT model to process longer texts during both prediction and fine-tuning. We dub this approach BELT (BERT For Longer Texts).
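The splitting step can be sketched in a few lines of plain Python. This is a minimal illustration, not the library's code: the function name split_into_chunks and its chunk_size/stride parameters are our own, and whitespace-separated words stand in for real tokenizer output. With a window of 6 and a stride of 3 it reproduces the three overlapping chunks of the example sentence "the man went to the store and bought a gallon of milk". In a real BERT setup each chunk would hold at most 510 tokens, leaving room for the [CLS] and [SEP] special tokens within the 512 limit.

```python
def split_into_chunks(tokens, chunk_size, stride):
    """Split a token sequence into overlapping windows (illustrative sketch).

    Each window holds at most `chunk_size` tokens and starts `stride` tokens
    after the previous one, so consecutive windows share
    `chunk_size - stride` tokens.
    """
    if not 0 < stride <= chunk_size:
        raise ValueError("stride must satisfy 0 < stride <= chunk_size")
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the text
        start += stride
    return chunks


tokens = "the man went to the store and bought a gallon of milk".split()
for chunk in split_into_chunks(tokens, chunk_size=6, stride=3):
    print(" ".join(chunk))
```

Setting stride equal to chunk_size gives non-overlapping chunks; a smaller stride gives every boundary word some context on both sides, at the cost of more chunks per text.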

More technical details are described in the documentation. We have also prepared a comprehensive blog post in two parts: part 1, part 2.

Attention is all you need, but 512 words is all you have

The BERT model's limit of 512 tokens goes back to the very beginning of transformer models. The self-attention mechanism at the core of the transformer, introduced in the groundbreaking 2017 paper Attention is all you need, scales quadratically with the sequence length. Unlike RNN or CNN models, which can process sequences of arbitrary length, transformers with full attention (like BERT) find it infeasible, or at least very expensive, to process long sequences. To overcome this issue, alternative approaches with sparse attention mechanisms were proposed in 2020: BigBird and Longformer.
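To make the quadratic cost concrete, the sketch below counts the query-key scores that full self-attention computes. It assumes the bert-base configuration of 12 layers with 12 attention heads each and ignores the rest of the model (embeddings, feed-forward layers), so it illustrates scaling only and is not a memory or runtime estimate.

```python
def attention_score_entries(seq_len, num_heads=12, num_layers=12):
    """Count query-key scores computed by full self-attention.

    Every token attends to every token, so each head in each layer
    builds a seq_len x seq_len score matrix.
    """
    return seq_len * seq_len * num_heads * num_layers


baseline = attention_score_entries(512)
for n in (512, 1024, 4096):
    entries = attention_score_entries(n)
    print(f"{n:>5} tokens: {entries:,} scores ({entries // baseline}x the 512-token cost)")
```

Doubling the length quadruples the count, and 4096 tokens cost 64 times as many scores as 512. Splitting the same 4096 tokens into 8 independent 512-token chunks costs only 8 times as much, which is the trade-off chunking makes: linear growth in cost, but no attention across chunk boundaries.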

BELT vs. BigBird vs. Longformer

Let us now clarify the key differences between the BELT approach to fine-tuning and the sparse attention models BigBird and Longformer:

  • The main difference is that BigBird and Longformer are not modified BERTs. They are models with different architectures, so they need to be pre-trained from scratch or downloaded.
  • This leads to the main advantage of the BELT approach: it can use any pre-trained BERT or RoBERTa model. A quick look at the HuggingFace Hub confirms that there are about 100 times more resources for BERT than for Longformer, which makes it easier to find a model appropriate for a specific task or language.
  • On the other hand, we have not done any benchmark tests yet. We believe that comparing the BELT approach with sparse-attention models might be very instructive. Some work in this direction was done in the 2022 paper Extend and Explain: Interpreting Very Long Language Models. The authors cited our implementation under its former name, roberta_for_longer_texts. We encourage more research in this direction.

Installation and dependencies

The project requires Python 3.9+. We recommend training the models on a GPU, so it is necessary to install a torch version compatible with the machine. First, check the version of the GPU drivers with the command nvidia-smi and, according to this table, choose the newest CUDA version compatible with these drivers (e.g. 11.1). Then install a torch build compatible with that CUDA version; here you can find which torch version matches the CUDA version on your machine.
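The driver check can also be scripted. The sketch below uses hypothetical helpers (the function names are ours) and assumes the usual nvidia-smi table header, which contains a field of the form "CUDA Version: 12.2". That field reports the newest CUDA version the installed driver supports, so a torch build for that version or an older one should work. On a machine without nvidia-smi the helper returns None, which is the cue to install the CPU-only build.

```python
import re
import shutil
import subprocess


def max_cuda_version(nvidia_smi_output):
    """Extract the 'CUDA Version: X.Y' field from nvidia-smi output as (X, Y).

    Returns None if the field is missing (format assumption, see above).
    """
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def query_local_gpu():
    """Run nvidia-smi if it is on PATH; return None on CPU-only machines."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    return max_cuda_version(result.stdout)


print(query_local_gpu())  # e.g. (11, 4) on a GPU machine, None without nvidia-smi
```

Returning a tuple keeps version comparisons numeric, so a check such as `query_local_gpu() >= (11, 1)` behaves correctly where a string comparison of "11.10" and "11.4" would not.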

Another option is to use the CPU-only version of torch:

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Next, we recommend installing via pip:

pip3 install belt-nlp

If you want to clone the repo in order to run tests or notebooks, you can use the requirements.txt file.

Model classes

Two main classes are implemented:

  • BertClassifierTruncated - base binary classification model; longer texts are truncated to 512 tokens
  • BertClassifierWithPooling - extended model for longer texts (more details in the documentation)
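Once every chunk has been scored, the pooled model has to reduce the per-chunk outputs to one prediction per text. The sketch below shows two common choices, mean and max pooling, applied to chunk probabilities. The helper pool_chunk_scores and the example scores are hypothetical; see the documentation for the pooling strategies BertClassifierWithPooling actually offers.

```python
from statistics import mean


def pool_chunk_scores(chunk_scores, strategy="mean"):
    """Reduce per-chunk probabilities to a single probability for the text.

    'mean' averages over all chunks; 'max' lets the most confident chunk decide.
    """
    if not chunk_scores:
        raise ValueError("a text must produce at least one chunk")
    if strategy == "mean":
        return mean(chunk_scores)
    if strategy == "max":
        return max(chunk_scores)
    raise ValueError(f"unknown pooling strategy: {strategy!r}")


# Made-up probabilities for a text that was split into three chunks.
scores = [0.2, 0.9, 0.7]
print(pool_chunk_scores(scores, "mean"), pool_chunk_scores(scores, "max"))
```

A truncating model, by contrast, only ever sees roughly the first chunk, so evidence that appears later in a long text (here the 0.9 chunk) cannot influence its prediction.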

Interface

The main methods are:

  • fit - fine-tune the model on the training set, given a list of raw texts and a list of labels
  • predict_classes - compute the predicted classes for a given list of raw texts. The model must be fine-tuned first.
  • predict_scores - compute the predicted probabilities for a given list of raw texts. The model must be fine-tuned first.
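The contract between the three methods can be illustrated with a toy stand-in. KeywordClassifier below is hypothetical and is not part of belt-nlp: it only memorizes words seen in positive training texts, but it follows the same call pattern (fit on raw texts and labels, then predict_scores and predict_classes on raw texts) and refuses to predict before fit, as the interface above requires. Deriving classes by thresholding scores at 0.5 is our assumption for the binary case.

```python
class KeywordClassifier:
    """Toy model mirroring the fit / predict_scores / predict_classes calls."""

    def __init__(self):
        self.positive_words = None  # filled in by fit()

    def fit(self, texts, labels):
        # "Training": keep words that occur in positive texts but never in negative ones.
        positive, negative = set(), set()
        for text, label in zip(texts, labels):
            (positive if label else negative).update(text.lower().split())
        self.positive_words = positive - negative

    def predict_scores(self, texts):
        if self.positive_words is None:
            raise RuntimeError("the model must be fitted before predicting")
        scores = []
        for text in texts:
            words = text.lower().split()
            hits = sum(word in self.positive_words for word in words)
            scores.append(hits / len(words) if words else 0.0)
        return scores

    def predict_classes(self, texts, threshold=0.5):
        # Binary decision derived from the scores (the threshold is an assumption).
        return [score >= threshold for score in self.predict_scores(texts)]


model = KeywordClassifier()
model.fit(["a great film", "an awful film"], [1, 0])
print(model.predict_scores(["great great", "awful film"]))
print(model.predict_classes(["great great", "awful film"]))
```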

Loading the pre-trained model

By default, the standard English bert-base-uncased model is used as the pre-trained model. However, any BERT or RoBERTa model can be used instead. To do this, set the parameter pretrained_model_name_or_path. It can be either:

  • a string with the name of a pre-trained model configuration to download from the Hugging Face Hub, e.g. roberta-base.
  • a path to a directory with the downloaded model, e.g.: ./my_model_directory/.
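The rule behind these two options can be sketched as follows. As far as we know, Hugging Face's from_pretrained() treats an existing local directory as a saved model and otherwise interprets the string as a Hub model id; the helper describe_model_source below is a hypothetical simplification of that check, useful for spotting a mistyped local path before it is silently treated as a Hub name.

```python
import os
import tempfile


def describe_model_source(pretrained_model_name_or_path):
    """Simplified view of how the value is interpreted (assumption, see above)."""
    if os.path.isdir(pretrained_model_name_or_path):
        return "local directory"
    return "hub model id"


with tempfile.TemporaryDirectory() as model_dir:
    print(describe_model_source(model_dir))  # local directory
# Treated as a Hub id unless a folder named "roberta-base" exists in the working directory.
print(describe_model_source("roberta-base"))
```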

Tests

To make sure everything works properly, run the command pytest tests -rA. By default, the tests train models on small samples on the CPU.

Examples

All examples use public datasets from the Hugging Face Hub.

Binary classification - prediction of sentiment of IMDB reviews

Multilabel classification - recognizing authors of Guardian articles

  • standard approach
  • belt
  • Notice the effectiveness of the BELT approach here: the test accuracy increased by 10%.

Regression - prediction of a 1-to-5 rating based on reviews from the Polish online e-commerce platform Allegro

Contributors

The project was created at MIM AI by:

If you want to contribute to the library, see the contributing info.

Version history

See CHANGELOG.md.

License

See the LICENSE file for license rights and limitations (MIT).

For Maintainers

The requirements.txt file can be updated with the command:

bash pip-freeze-without-torch.sh > requirements.txt

This script saves all dependencies of the current active environment except torch.
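For readers who want the same effect without the shell script, here is a minimal Python sketch of the filtering step. It is our approximation, not the contents of pip-freeze-without-torch.sh: drop_torch removes only the line for the package named exactly torch and keeps torchvision, torchaudio, and everything else; whether the real script also excludes those is not stated here. The input would normally be the lines printed by python -m pip freeze; the list below is a made-up example.

```python
import re


def drop_torch(requirement_lines):
    """Keep pip-freeze lines except the one for the `torch` package itself."""
    kept = []
    for line in requirement_lines:
        # Package name = text before the first version, extra, URL, or marker symbol.
        name = re.split(r"[=<>!~\[@ ;]", line.strip(), maxsplit=1)[0]
        if name.lower() != "torch":
            kept.append(line)
    return kept


frozen = ["numpy==1.26.4", "torch==2.1.0", "transformers==4.35.2"]  # example lines
print("\n".join(drop_torch(frozen)))
```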

To publish the next version of the package to PyPI, follow these steps:

  • First, increment the package version in pyproject.toml.
  • Then build the new version: run python3.9 -m build from the main folder.
  • Finally, upload to PyPI: twine upload dist/* (the two newly created files).
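The first step can be automated with a small helper. bump_version is hypothetical and assumes pyproject.toml declares the version on a single line of the form version = "X.Y.Z" (for example under [project]); it rewrites only the first such line and leaves the rest of the file untouched. The example file content below is made up for illustration.

```python
import re


def bump_version(pyproject_text, part="patch"):
    """Return pyproject.toml text with the first `version = "X.Y.Z"` line bumped."""
    pattern = re.compile(r'^(version\s*=\s*")(\d+)\.(\d+)\.(\d+)(")', re.MULTILINE)
    match = pattern.search(pyproject_text)
    if match is None:
        raise ValueError('no version = "X.Y.Z" line found')
    major, minor, patch = (int(g) for g in match.group(2, 3, 4))
    if part == "major":
        major, minor, patch = major + 1, 0, 0
    elif part == "minor":
        minor, patch = minor + 1, 0
    elif part == "patch":
        patch += 1
    else:
        raise ValueError(f"unknown part: {part!r}")
    new_line = f"{match.group(1)}{major}.{minor}.{patch}{match.group(5)}"
    return pyproject_text[:match.start()] + new_line + pyproject_text[match.end():]


example = '[project]\nname = "belt-nlp"\nversion = "1.1.0"\n'
print(bump_version(example, "minor"))
```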

bert_for_longer_texts's People

Contributors

bm371613, dependabot[bot], jstremme, michalbrzozowski91, mwachnicki


bert_for_longer_texts's Issues

A few general questions

Hello there! Thank you for this nice project ✨ @mwachnicki @MichalBrzozowski91
I'm really enjoying working through the details! I've just got a few general questions I hope you can help me with.

Let's consider Devlin's example and say we have a 3x6 mini-batch as a result of splitting our input sequence into 3 chunks:

the man went to the store
to the store and bought a
and bought a gallon of milk

BELT allows processing the mini-batch in one go and returns a single, pooled probability value as a result.

Question 1
As far as the attention mechanism goes, am I right to understand that this is applied separately to each chunk? In other words, the tokens in the first chunk do not attend to the tokens in the second and third one, correct?

Question 2
Devlin suggests applying an attention mask to ensure boundary words are not considered twice; in our example, "to the store" and "and bought a" in the second and third chunk, respectively. Why don't we simply split the original sentence so that the chunks do not overlap? For example:

the man went to the store
and bought a gallon of milk

What is the purpose of keeping these overlapping bits if we have to mask them anyway?

Question 3
If my considerations in Question 1 are correct and attention is applied separately on each individual chunk, wouldn't it be beneficial to not mask the overlapping boundary words? Intuitively, I'd say this increases the context of each chunk, making them more similar to each other "in the eyes of the model".

Thanks again for the great work!

RuntimeError with the following message: "mat1 and mat2 shapes cannot be multiplied (2x512 and 768x1)

I'm encountering a RuntimeError with the following message: "mat1 and mat2 shapes cannot be multiplied (2x512 and 768x1)" when testing the fit and predict methods for a model with pooling using a pretrained model. Has anyone encountered this issue before, and if so, do you have any suggestions on how to resolve it?
Full error log:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[22], line 1
----> 1 model.fit(X_train, y_train, epochs=1)

File /opt/conda/lib/python3.10/site-packages/belt_nlp/bert.py:80, in BertClassifier.fit(self, x_train, y_train, epochs)
     76 dataloader = DataLoader(
     77     dataset, sampler=RandomSampler(dataset), batch_size=self.batch_size, collate_fn=self.collate_fn
     78 )
     79 for epoch in range(epochs):
---> 80     self._train_single_epoch(dataloader, optimizer)

File /opt/conda/lib/python3.10/site-packages/belt_nlp/bert.py:126, in BertClassifier._train_single_epoch(self, dataloader, optimizer)
    123 for step, batch in enumerate(dataloader):
    125     labels = batch[-1].float().cpu()
--> 126     predictions = self._evaluate_single_batch(batch)
    127     loss = cross_entropy(predictions, labels) / self.accumulation_steps
    128     loss.backward()

File /opt/conda/lib/python3.10/site-packages/belt_nlp/bert_with_pooling.py:124, in BertClassifierWithPooling._evaluate_single_batch(self, batch)
    119 attention_mask_combined_tensors = torch.stack(
    120     [torch.tensor(x).to(self.device) for x in attention_mask_combined]
    121 )
    123 # get model predictions for the combined batch
--> 124 preds = self.neural_network(input_ids_combined_tensors, attention_mask_combined_tensors)
    126 preds = preds.flatten().cpu()
    128 # split result preds into chunks

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/belt_nlp/bert.py:180, in BertClassifierNN.forward(self, input_ids, attention_mask)
    177 x = x[0][:, 0, :]  # take <s> token (equiv. to [CLS])
    179 # classification head
--> 180 x = self.linear(x)
    181 x = self.sigmoid(x)
    182 return x

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x512 and 768x1)

I appreciate any insights or assistance provided!

Outputting Attentions

Is it possible to output the attentions of each chunk using output_attentions=True?

MaskedLM for longer texts

Is it possible to apply the same logic to RoBERTa for the MaskedLM model? I need it for pretraining on a custom dataset with long texts. Thanks!

plz help me

  1. Is there a method for multi-class classification?

  2. Are you currently conducting research on multi-class classification?

  3. I get this warning message when training the model: "Token indices sequence length is longer than the specified maximum sequence length for this model (23716 > 512). Running this sequence through the model will result in indexing errors." Does this warning not affect the model training?

Managing GPU memory for token length more than 4000

Hi

Your code helped me a lot in understanding the chunking process. When I try to fine-tune with a token length of 4000+, the model fails with an out-of-memory exception. I have tried a batch size of 2, including on a larger 48GB GPU. I can see that data is continuously pushed onto the GPU, which exhausts its memory. Is there a way to better manage memory for samples represented by 4000+ tokens?

QnA system using BERT

I'm trying to build a QnA system with BERT where I will provide a PDF document. Because of the 512-token limit, it cannot take longer texts such as a PDF.
I want to know which part of this repository I need to use for BERT to work well even with a PDF.

Split Sizes Throws Error

@MichalBrzozowski91, thanks for this project! Really great stuff.

I get the following error in main.py on a four GPU setup when attempting to fine-tune a BERT model:

With batch size 1:
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 800 (input tensor's size at dimension 0), but got split_sizes=[16]

With batch size 4:
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 2750 (input tensor's size at dimension 0), but got split_sizes=[25, 6, 8, 16]

I would expect number_of_chunks to be of variable size for each record in the batch, but no matter my batch size, I seem to get an error at preds_split = preds.split(number_of_chunks) in main.py.

Any idea what I might be missing?

text length warning

Token indices sequence length is longer than the specified maximum sequence length for this model (2268 > 512). Running this sequence through the model will result in indexing errors
Can I ignore the warning notice above? Why is it popping up?

Obtain embedding vectors

Hello and thank you for sharing your work!
I would like to know if there is a way to obtain embedding vectors of one (or more) sentences fed into the model.
Hope you could help me.
Thank you in any case
