bigbird's Introduction

BigBird

This repository tracks my work on porting Google's BigBird to 🤗 Transformers. I trained 🤗's BigBirdModel & FlaxBigBirdModel (with suitable heads) on some of the datasets mentioned in the paper Big Bird: Transformers for Longer Sequences. This repository also hosts the training scripts.

You can find the quick demo in 🤗spaces: https://hf.co/spaces/vasudevgupta/BIGBIRD_NATURAL_QUESTIONS

Check out the following notebooks to dive deeper into using 🤗 BigBird:

| Description | Notebook |
| --- | --- |
| Flax BigBird evaluation on natural-questions dataset | Open In Colab |
| PyTorch BigBird evaluation on natural-questions dataset | Open In Colab |
| PyTorch BigBirdPegasus evaluation on PubMed dataset | Open In Colab |
| How to use 🤗's BigBird (RoBERTa & Pegasus) for inference | Open In Colab |

Updates @ 🤗

| Description | Date | Link |
| --- | --- | --- |
| Script for training FlaxBigBird (with QA heads) on natural-questions | June 25, 2021 | PR #12233 |
| Added Flax/Jax BigBird-RoBERTa to 🤗 Transformers | June 15, 2021 | PR #11967 |
| Added PyTorch BigBird-Pegasus to 🤗 Transformers | May 7, 2021 | PR #10991 |
| Published blog post @ 🤗 Blog | March 31, 2021 | Link |
| Added PyTorch BigBird-RoBERTa to 🤗 Transformers | March 30, 2021 | PR #10183 |

Training BigBird

I trained BigBird on the natural-questions dataset, which takes around 100 GB of disk space. Before diving into the scripts, set up your system with the following commands:

```shell
# clone this repository
git clone https://github.com/vasudevgupta7/bigbird

# install requirements
cd bigbird
pip3 install -r requirements.txt

# switch to the code directory
cd src

# create the data directory for preparing natural-questions
mkdir -p data
```

Now that your system is ready, let's preprocess and prepare the dataset for training. Just run the following commands:

```shell
# download the ~100 GB dataset from the 🤗 Hub & prepare training data in `data/nq-training.jsonl`
PROCESS_TRAIN=true python3 prepare_natural_questions.py

# prepare validation data in `data/nq-validation.jsonl`
PROCESS_TRAIN=false python3 prepare_natural_questions.py
```

The above commands first download the dataset from the 🤗 Hub and then prepare it for training. Remember, this downloads ~100 GB, so you need a good internet connection and enough disk space (~250 GB free). Preparing the dataset takes ~3 hours.
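The prepared files are JSON Lines, one example per line. A minimal sketch for inspecting them (the field names below are hypothetical — check the actual keys written by prepare_natural_questions.py):

```python
import json

# Write a tiny stand-in file in the same JSON Lines layout
# (field names here are illustrative, not the script's real schema).
sample = {"question": "who wrote bigbird?", "context": "...", "category": "short"}
with open("nq-sample.jsonl", "w") as f:
    f.write(json.dumps(sample) + "\n")

# Read it back line by line, as a training loop would
with open("nq-sample.jsonl") as f:
    examples = [json.loads(line) for line in f]
print(examples[0]["question"])  # who wrote bigbird?
```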

Now that you have prepared the dataset, let's start training. You have two options:

  1. Train PyTorch version of BigBird with 🤗 Trainer
  2. Train FlaxBigBird with custom training loop

PyTorch BigBird distributed training on multiple GPUs

```shell
# distributed training (using nq-training.jsonl & nq-validation.jsonl) on 2 GPUs
python3 -m torch.distributed.launch --nproc_per_node=2 train_nq_torch.py
```

Flax BigBird distributed training on TPUs/GPUs

```shell
# start training
python3 train_nq_flax.py

# for hparams tuning, try a wandb sweep (random search by default)
wandb sweep sweep_flax.yaml
wandb agent <agent-id-created-by-above-CMD>
```
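For reference, a wandb random-search sweep config has this general shape (a hypothetical sketch — the actual sweep_flax.yaml in this repo defines its own parameters and metric):

```yaml
program: train_nq_flax.py
method: random
metric:
  name: eval_loss
  goal: minimize
parameters:
  lr:
    values: [1e-5, 3e-5, 5e-5]
  batch_size:
    values: [4, 8]
```

`wandb sweep` registers this config and prints the agent id that `wandb agent` then consumes.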

You can find my fine-tuned checkpoints on HuggingFace Hub. Refer to the following table:

| Checkpoint | Description |
| --- | --- |
| flax-bigbird-natural-questions | Obtained by running the train_nq_flax.py script |
| bigbird-roberta-natural-questions | Obtained by running the train_nq_torch.py script |

To see how the above checkpoints perform on the QA task, check out this example:

The context is just a tweet taken from the 🤗 Twitter handle. 💥💥💥

bigbird's People

Contributors

bluehephaestus, patrickvonplaten, thevasudevgupta


bigbird's Issues

Remove 50% of data?

Hi, I was looking over the code, and it seems that on a null answer you remove "50 % samples":

```python
if cat == "null" and np.random.rand() < 0.6:
    continue  # removing 50 % samples
```

However, I was wondering why you did this; I think the original paper that BigBird based its preprocessing on (here) downsampled by a factor of 50x, not by 50%.
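As an aside, a quick simulation shows that the `np.random.rand() < 0.6` condition drops roughly 60 % of null samples, not the 50 % the comment claims (a minimal sketch, assuming independent uniform draws):

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed for reproducibility

n = 100_000
# simulate the filter: a "null" sample is skipped when rand() < 0.6
kept = sum(1 for _ in range(n) if not (rng.random() < 0.6))
print(f"fraction kept: {kept / n:.3f}")  # ~0.4, i.e. ~60 % removed
```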

QA system

Hi,

thanks for your good work!

I'm testing your model (BigBirdRobertaQA) in a retriever-reader architecture. I retrieve the top 5 paragraphs (they are a couple of sentences each), merge them (joining them with a space between) and feed them as the context to the model.

  1. Is this the best way to approach it? If so, what is the recommended max length for the tokenizer?
  2. I expect not, but is there a way to influence what kind of answer you get (short, long, yes/no)?

BigBirdForNaturalQuestions

Hi Vasudevgupta,

I am interested in predicting answer categories (long, short). As suggested on the Hugging Face model page (https://huggingface.co/vasudevgupta/bigbird-roberta-natural-questions), I am trying to use BigBirdForNaturalQuestions, but I cannot find this class/model anywhere. Could you please point me to where BigBirdForNaturalQuestions is defined and explain how to get long and short answers from that model?

Regards,
Kishore

Can I use roberta-base checkpoint for this?

Hi Gupta,

Thank you for your hard work.
I want to use BigBird on my own dataset in another language, but the google/bigbird-roberta-base checkpoint was trained on English. Can I load another checkpoint (for example, roberta-base) into your model?

BigBird Pegaus Training example

Hi @vasudevgupta7,

First of all, nice job merging BigBird into Hugging Face Transformers!
I wanted to try training my own BigBird-Pegasus model, and I saw that you trained several of these. Browsing both Hugging Face and your repositories, I did not find an example, and wanted to ask whether you have a simple example you could share to get me started.

ImportError: cannot import name 'CATEGORY_MAPPING' from 'params'

Hi, I am getting an import error for CATEGORY_MAPPING from params. Could you please help me with this?

```python
import torch
import numpy as np
from bigbird.src.train_nq_torch import BigBirdForNaturalQuestions
from params import CATEGORY_MAPPING
from transformers import BigBirdTokenizer

CATEGORY_MAPPING = {v: k for k, v in CATEGORY_MAPPING.items()}
CATEGORY_MAPPING
```


```
ImportError                               Traceback (most recent call last)
in ()
      2 import numpy as np
      3 from bigbird.src.train_nq_torch import BigBirdForNaturalQuestions
----> 4 from params import CATEGORY_MAPPING
      5 from transformers import BigBirdTokenizer
      6

ImportError: cannot import name 'CATEGORY_MAPPING' from 'params' (/usr/local/lib/python3.7/dist-packages/params/__init__.py)
```
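The traceback shows the import resolving to the unrelated `params` package on PyPI rather than the repository's own `params.py`, which must be on `sys.path`. The inversion pattern itself can be demonstrated with a stand-alone dictionary (the category names and ids below are hypothetical stand-ins for the repo's actual CATEGORY_MAPPING):

```python
# Hypothetical stand-in for the repository's CATEGORY_MAPPING (params.py);
# the real mapping lives in the repo's src/ directory and must be importable.
CATEGORY_MAPPING = {"null": 0, "short": 1, "long": 2, "yes": 3, "no": 4}

# Invert id -> name so integer model predictions become readable labels.
ID_TO_CATEGORY = {v: k for k, v in CATEGORY_MAPPING.items()}
print(ID_TO_CATEGORY[1])  # short
```

Running the snippet from inside `src/` (or adding it to `sys.path`) avoids shadowing by the PyPI `params` package.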
