
TR-BERT

Source code and dataset for "TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference".

The code is based on Hugging Face's Transformers. Thanks to them! We will release all the source code in the future.

Requirements

Install the dependencies and apex:

pip3 install -r requirement.txt
pip3 install --editable transformers

Pretrained models

Download the pretrained DistilBERT-3layer and BERT-1024 models from Google Drive/Tsinghua Cloud.

Classification

Download the IMDB, Yelp, 20News datasets from Google Drive/Tsinghua Cloud.

Download the Hyperpartisan dataset and randomly split it into train/dev/test sets:

python3 split_hyperpartisan.py
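split_hyperpartisan.py does the actual split. The following is only a minimal sketch of a seeded random train/dev/test split, to show the idea. The function name split_dataset, the 80/10/10 ratios and the seed are hypothetical. They are not taken from the script.

```python
import random


def split_dataset(examples, ratios=(0.8, 0.1, 0.1), seed=42):
    """Shuffle examples with a fixed seed, then cut them into train/dev/test lists.

    Hypothetical helper: the 80/10/10 ratios and seed 42 are defaults chosen for
    this sketch, not values read from split_hyperpartisan.py.
    """
    if len(ratios) != 3 or abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("ratios must be three fractions that sum to 1")
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # seeded, so the split is reproducible
    n_train = int(len(shuffled) * ratios[0])
    n_dev = int(len(shuffled) * ratios[1])
    train = shuffled[:n_train]
    dev = shuffled[n_train:n_train + n_dev]
    test = shuffled[n_train + n_dev:]  # the remainder goes to test
    return train, dev, test


if __name__ == "__main__":
    train, dev, test = split_dataset(range(100))
    print(len(train), len(dev), len(test))  # prints: 80 10 10
```

Fixing the seed keeps the dev and test sets the same across runs, so results from different training stages stay comparable.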

Train BERT/DistilBERT Model

Use flag --do_train:

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path bert-base-uncased --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 16 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 5  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval  --evaluate_during_training  --do_train

where task_name can be set to imdb/yelp_f/20news/hyperpartisan to select the task, and model_type can be set to bert/distilbert to select the model.
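To run the same fine-tuning step for several task/model pairs, the command above can be generated in a loop. This is a minimal sketch. It assumes that the IMDB hyperparameters carry over to the other tasks and that each dataset sits in a directory named after its task_name. The helper finetune_command and the output_dir naming pattern are hypothetical.

```python
import shlex

TASK_NAMES = ("imdb", "yelp_f", "20news", "hyperpartisan")
MODEL_TYPES = ("bert", "distilbert")


def finetune_command(task_name, model_type, model_name_or_path, output_dir):
    """Return the README's fine-tuning command for one task/model pair.

    Hypothetical helper: hyperparameters are copied from the IMDB example, and
    --data_dir is assumed to equal the task name.
    """
    if task_name not in TASK_NAMES:
        raise ValueError(f"unknown task_name: {task_name}")
    if model_type not in MODEL_TYPES:
        raise ValueError(f"unknown model_type: {model_type}")
    args = [
        "python3", "run_classification.py",
        "--task_name", task_name,
        "--model_type", model_type,
        "--model_name_or_path", model_name_or_path,
        "--data_dir", task_name,
        "--max_seq_length", "512",
        "--per_gpu_train_batch_size", "8",
        "--per_gpu_eval_batch_size", "16",
        "--gradient_accumulation_steps", "4",
        "--learning_rate", "3e-5",
        "--save_steps", "2000",
        "--num_train_epochs", "5",
        "--output_dir", output_dir,
        "--do_lower_case", "--do_eval", "--evaluate_during_training", "--do_train",
    ]
    return shlex.join(args)  # shell-quotes any argument that needs it


if __name__ == "__main__":
    for task in TASK_NAMES:
        print(finetune_command(task, "bert", "bert-base-uncased", f"{task}_models/bert_base"))
```

The helper uses shlex.join (Python 3.8+) rather than a plain string join. That way a path containing spaces still produces a valid shell command.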

Compute Gradient for Residual Strategy

Use flag --do_eval_grad.

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval_grad

This step does not currently support DataParallel or DistributedDataParallel and should be run on a single GPU.

Train the policy network alone

Start from the checkpoint of the task-specific fine-tuned model. Change model_type from bert to autobert, and run with the flags --do_train --train_rl:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/auto_1  --do_lower_case  --do_train --train_rl --alpha 1 --guide_rate 0.5

where alpha is the harmonic coefficient of the length penalty and guide_rate is the proportion of imitation-learning steps. model_type can be set to autobert/distilautobert to apply token reduction to BERT/DistilBERT.
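The README defines alpha and guide_rate only in words. This is a minimal sketch of how the two could act during policy training, not the code in run_classification.py. It makes two assumptions:

* The reward is the log-probability of the gold label minus alpha times the average fraction of tokens kept across the selection layers.
* Each step is independently an imitation step with probability guide_rate.

Both forms and the function names are hypothetical.

```python
import math
import random


def length_penalized_reward(prob_gold, kept_per_layer, total_tokens, alpha):
    """Hypothetical reward: log p(gold) minus alpha times the mean kept fraction.

    prob_gold:      probability the model assigns to the gold label (assumed input).
    kept_per_layer: number of tokens kept at each layer that makes a selection.
    total_tokens:   input length before any reduction.
    A larger alpha pushes the policy to keep fewer tokens.
    """
    if not 0.0 < prob_gold <= 1.0:
        raise ValueError("prob_gold must be in (0, 1]")
    kept_fraction = sum(kept_per_layer) / (len(kept_per_layer) * total_tokens)
    return math.log(prob_gold) - alpha * kept_fraction


def imitation_schedule(num_steps, guide_rate, seed=0):
    """Hypothetical schedule: mark each step True (imitation) or False (RL).

    Each step is an independent draw, so about guide_rate of the steps are True.
    """
    rng = random.Random(seed)
    return [rng.random() < guide_rate for _ in range(num_steps)]


if __name__ == "__main__":
    print(length_penalized_reward(1.0, [256, 128], 512, alpha=1.0))  # prints: -0.375
    print(sum(imitation_schedule(1000, guide_rate=0.5)))
```

Under this assumed form, alpha trades accuracy against speed. Keeping fewer tokens raises the reward only while the gold-label probability does not fall too far.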

Compute Logits for Knowledge Distillation

Use flag --do_eval_logits.

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval_logits

This step does not currently support DataParallel or DistributedDataParallel and should be run on a single GPU.
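The logits saved by --do_eval_logits act as soft targets from the fine-tuned teacher. This is a minimal sketch of a standard Hinton-style distillation loss (temperature-softened KL divergence), to show how such logits are typically used. The README does not say whether --train_teacher uses exactly this form or which temperature it uses. Treat both as assumptions.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of a list of logits at the given temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max before exp to avoid overflow
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """KL(teacher || student) between temperature-softened distributions.

    Scaling by temperature**2 keeps gradient magnitudes comparable across
    temperatures; the temperature itself is an assumed hyperparameter.
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(
        pt * (math.log(pt) - math.log(ps))
        for pt, ps in zip(p_teacher, p_student)
        if pt > 0  # zero-probability teacher classes contribute nothing
    )
    return kl * temperature ** 2


if __name__ == "__main__":
    teacher = [2.0, 0.5, -1.0]  # illustrative logits, not real model outputs
    print(distillation_loss([1.5, 0.7, -0.8], teacher, temperature=2.0))
```

The loss is zero when student and teacher logits match and grows as the student's distribution drifts from the teacher's.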

Train the whole network with both the task-specific objective and the RL objective

Start from the checkpoint of the --train_rl model and run with the flags --do_train --train_both --train_teacher:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/auto_1 --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 1 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/auto_1_both  --do_lower_case  --do_train --train_both --train_teacher --alpha 1

Evaluate

Use flag --do_eval:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/auto_1_both  --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 1  --output_dir imdb_models/auto_1_both  --do_lower_case  --do_eval --eval_all_checkpoints

When the evaluation batch size is larger than 1, the same number of tokens is retained for every instance in a batch.
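Keeping the same number of tokens per instance lets the reduced batch stay a rectangular tensor. This minimal sketch shows top-k selection with one shared k per batch. It assumes the shared k is the batch maximum of the per-instance keep counts, which favours accuracy over speed. The README does not say how the shared count is chosen, and the function name is hypothetical.

```python
def select_tokens_batch(scores, keep_counts):
    """Keep the same number of tokens for every instance in a batch.

    scores:      one list of per-token importance scores per instance
                 (stand-ins for the policy network's outputs).
    keep_counts: how many tokens each instance would keep on its own.
    The shared k is the batch maximum of keep_counts (an assumption).
    Returns, per instance, the kept positions in their original order.
    """
    if not scores or len(scores) != len(keep_counts):
        raise ValueError("need one keep count per instance in a non-empty batch")
    k = max(keep_counts)
    kept = []
    for row in scores:
        k_row = min(k, len(row))  # a shorter input keeps everything it has
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        kept.append(sorted(ranked[:k_row]))  # restore left-to-right token order
    return kept


if __name__ == "__main__":
    batch_scores = [[0.9, 0.1, 0.5, 0.3], [0.2, 0.8, 0.7, 0.1]]
    print(select_tokens_batch(batch_scores, [1, 2]))  # prints: [[0, 2], [1, 2]]
```

Using the minimum or the mean instead of the maximum would give faster but more lossy evaluation. That is one reason to evaluate with --per_gpu_eval_batch_size 1, as in the command above.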

Initialize

For the IMDB dataset, we find that initializing the selector with a heuristic objective before training the policy network alone gives slightly better performance. For other datasets, this step makes little difference. Run this step with the flags --do_train --train_init:

python3 trans_imdb_rank.py
python3 run_classification.py  --task_name imdb  --model_type initbert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/bert_init  --do_lower_case  --do_train --train_init 

Question Answering

Download the SQuAD 2.0 dataset.

Download the MRQA dataset with our split from Google Drive/Tsinghua Cloud.

Download the HotpotQA dataset from the Transformer-XH repository, where the paragraphs for each question are retrieved via TF-IDF, entity linking, and hyperlinks, and then re-ranked by a BERT re-ranker.

Download the TriviaQA dataset, where paragraphs are re-ranked by the linear passage re-ranker in DocQA.

Download the WikiHop dataset.

The training procedure for question answering models is similar to that for text classification models: run with the flags --do_train, then --do_train --train_rl, then --do_train --train_both --train_teacher. The script for each dataset:

SQuAD: run_squad.py with the flag --version_2_with_negative

NewsQA / NaturalQA: run_mrqa.py

RACE: run_race_classify.py

HotpotQA: run_hotpotqa.py

TriviaQA: run_triviaqa.py

WikiHop: run_wikihop.py
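The dataset-to-script mapping above, the three training stages, and the harmonic coefficients from the table in the next section can be combined into one small lookup. This is a minimal sketch that prints only the stage flags. Model paths, --model_type and all other hyperparameters are deliberately left out. Two points are assumptions: that every QA script accepts --alpha the way run_classification.py does, and the helper name stage_commands. WikiHop has no listed coefficient, so the helper emits no --alpha for it.

```python
# Script per dataset, as listed in this README.
QA_SCRIPTS = {
    "squad": "run_squad.py",
    "newsqa": "run_mrqa.py",
    "naturalqa": "run_mrqa.py",
    "race": "run_race_classify.py",
    "hotpotqa": "run_hotpotqa.py",
    "triviaqa": "run_triviaqa.py",
    "wikihop": "run_wikihop.py",
}

# (alpha for --train_rl, alpha for --train_both), from the harmonic coefficient table.
ALPHA = {
    "squad": (5, 5),
    "newsqa": (3, 5),
    "naturalqa": (2, 2),
    "race": (0.5, 0.1),
    "hotpotqa": (0.1, 4),
    "triviaqa": (0.5, 1),
}

# The three stages, run in turn.
STAGE_FLAGS = [
    ["--do_train"],
    ["--do_train", "--train_rl"],
    ["--do_train", "--train_both", "--train_teacher"],
]


def stage_commands(dataset):
    """Hypothetical helper: return the three stage commands (flags only) for a QA dataset."""
    key = dataset.lower()
    script = QA_SCRIPTS[key]
    extra = ["--version_2_with_negative"] if key == "squad" else []
    alphas = ALPHA.get(key)  # None for WikiHop, which the table does not list
    commands = []
    for stage, flags in enumerate(STAGE_FLAGS):
        args = ["python3", script, *flags, *extra]
        if stage > 0 and alphas is not None:  # alpha only applies to the RL stages
            args += ["--alpha", str(alphas[stage - 1])]
        commands.append(" ".join(args))
    return commands


if __name__ == "__main__":
    for command in stage_commands("SQuAD"):
        print(command)
```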

Harmonic Coefficient Lambda

Example values of the harmonic coefficient (passed as --alpha) for each dataset and training stage:

Dataset        alpha (--train_rl)  alpha (--train_both)
SQuAD 2.0      5                   5
NewsQA         3                   5
NaturalQA      2                   2
RACE           0.5                 0.1
YELP.F         2                   0.5
20News         1                   1
IMDB           1                   1
HotpotQA       0.1                 4
TriviaQA       0.5                 1
Hyperpartisan  0.01                0.01

Cite

If you use the code, please cite this paper:

@inproceedings{ye2021trbert,
  title={TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference},
  author={Deming Ye and Yankai Lin and Yufei Huang and Maosong Sun},
  booktitle={Proceedings of NAACL 2021},
  year={2021}
}
