kssteven418 / ltp

[KDD'22] Learned Token Pruning for Transformers

Home Page: https://arxiv.org/abs/2107.00910

License: Apache License 2.0

Shell 0.23% Makefile 0.02% Dockerfile 0.05% Jsonnet 0.01% Python 94.18% Jupyter Notebook 5.52%
natural-language-processing transformer bert pruning model-compression efficient-model efficient-neural-networks

ltp's Introduction

LTP: Learned Token Pruning for Transformers


Check our paper for more details.

Installation

We follow the same installation procedure as the original Huggingface transformer repo.

pip install sklearn scipy datasets torch
pip install -e .  # in the top directory

Prepare Checkpoints

LTP is implemented on top of Huggingface transformers' I-BERT implementation. Therefore, we first need to generate a checkpoint of I-BERT fine-tuned on the target downstream task. While you can do this with the original Huggingface repository, we also provide our base branch ltp/base, where you can run the following code to fine-tune I-BERT on the GLUE tasks.

git checkout ltp/base
cd examples/text-classification
python run_glue.py --model_name_or_path kssteven/ibert-roberta-base --output_dir {CKPT} --task {TASK} --do_train --do_eval {--some_more_arguments}
  • {TASK}: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
  • Please refer to the Huggingface tutorial and the official documentation for more details on arguments and hyperparameters.
  • Note that, by default, ibert behaves the same as roberta (see this tutorial); hence, the resulting model will be the same as roberta-base fine-tuned on the target GLUE task.

The final model will be checkpointed in {CKPT}.

  • Remove {CKPT}/trainer_state.json.
  • In the configuration file {CKPT}/config.json, change (1) "architectures" to ["LTPForSequenceClassification"] and (2) "model_type" to "ltp".
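
The two preparation steps above can also be scripted. A minimal sketch, assuming the standard checkpoint layout written by run_glue.py ({CKPT} is a placeholder path):

import json
import os

CKPT = "path/to/ckpt"  # placeholder: the fine-tuned I-BERT checkpoint directory

# 1. Remove the leftover trainer state.
trainer_state = os.path.join(CKPT, "trainer_state.json")
if os.path.exists(trainer_state):
    os.remove(trainer_state)

# 2. Repoint the config at the LTP model class.
config_path = os.path.join(CKPT, "config.json")
with open(config_path) as f:
    config = json.load(f)
config["architectures"] = ["LTPForSequenceClassification"]
config["model_type"] = "ltp"
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)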

Run Learned Token Pruning

Add the following lines in the configuration file {CKPT}/config.json.

"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01, 

final_token_threshold determines the token threshold of the last layer, and the thresholds of the remaining layers are linearly scaled from it. For instance, when final_token_threshold, i.e., the threshold for the last (12th) layer, is set to 0.01, the thresholds for the 3rd, 6th, and 9th layers will be 0.0025, 0.005, and 0.0075, respectively. This value is a hyperparameter, and we found that 0.01 works well in many cases.
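
A minimal sketch of this linear scaling, assuming a 12-layer encoder (the helper name below is ours, not part of the codebase):

def layer_thresholds(final_token_threshold, num_layers=12):
    # The threshold of layer l is final_token_threshold * l / num_layers,
    # so the last layer uses final_token_threshold itself.
    return [final_token_threshold * l / num_layers for l in range(1, num_layers + 1)]

thresholds = layer_thresholds(0.01)
print(thresholds[2], thresholds[5], thresholds[8], thresholds[11])
# -> 0.0025, 0.005, 0.0075, 0.01 (3rd, 6th, 9th, and 12th layers, up to float rounding)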

The learnable mode consists of 2 stages: soft threshold and hard threshold. Please refer to our paper for more details.

1. Soft Threshold

We first train the model using the soft threshold mode. This trains the thresholds as well as the model parameters to search for the best threshold configuration.

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr 2e-5 --temperature {T}\
  --lambda 0.1 --weight_decay 0 --bs 64 --masking_mode soft --epoch {epoch} --save_step 100 --no_load
  • {TASK}: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
  • You can assign a different learning rate via lr, but 2e-5 worked fine.
  • We set {epoch} to 10 for smaller datasets (e.g., RTE, MRPC) and 1 for larger datasets (e.g., SST2, QNLI, QQP).
  • The --no_load flag skips loading the best model at the end of training (i.e., the final checkpoint will be the one from the end of training).
  • lambda is an important hyperparameter that controls the pruning level: the higher the value, the more tokens we prune. 0.01 ~ 0.2 worked well in many cases, but we recommend empirically searching for the best value; a sketch of the soft masking it regularizes follows after this list.
  • temperature is another hyperparameter, and 1e-3 ~ 1e-5 worked well. In the paper, we searched over {1e−4, 2e−4, 5e−4, 1e−3, 2e−3}.
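
During the soft stage the model does not drop tokens; instead it scales them with a differentiable mask so that gradients reach the thresholds as well as the weights. A minimal sketch of that masking step, following the sigmoid formulation used in the repository's LTP layer (the same expression appears in the code excerpt quoted in the issues below); the tensor contents here are illustrative:

import torch

def soft_mask(pruning_scores, threshold, temperature):
    # Tokens whose importance score falls below the (learnable) threshold are
    # softly suppressed rather than removed, keeping the operation differentiable.
    return torch.sigmoid((pruning_scores - threshold) / temperature)

scores = torch.tensor([[0.020, 0.004, 0.011, 0.001]])  # one sequence, four tokens
mask = soft_mask(scores, threshold=torch.tensor(0.01), temperature=1e-3)
# the layer output is then scaled as: layer_output * mask.unsqueeze(-1)

The --lambda term regularizes these masks, which is what pushes more tokens toward being pruned as training proceeds.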

The final model will be checkpointed in {CKPT_soft} = checkpoints/base/{TASK}/absolute_threshold/rate_{final_token_threshold}/temperature_{T}/lambda_{lambda}/lr_{lr}. Remove trainer_state.json from the checkpoint directory {CKPT_soft}.

2. Hard Threshold

Once the thresholds are learned, we fix their values, switch back to the hard threshold mode, and fine-tune the model parameters only.

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT_soft} --lr {LR} --bs 64 --masking_mode hard --epoch 5 
  • We used {LR} ∈ {0.5, 1, 2}e-5 in the paper.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. By default, evaluation runs once per epoch.

The final model will be checkpointed in {CKPT_soft}/hard/lr_{LR}.

Run Baseline Methods

We additionally provide code to reproduce the baseline methods used in our paper (i.e., top-k and manual threshold).

Top-k Token Pruning

Add the following lines in {CKPT}/config.json.

"prune_mode": "topk",
"token_keep_rate": 0.2,

The token keep rate of the first three layers is fixed to 1, the last layer uses token_keep_rate, and the keep rates of the remaining layers are scaled linearly in between (see the sketch below). The smaller token_keep_rate is, the more aggressively we prune tokens. You can also assign a negative number to token_keep_rate; in that case, the keep rate of each layer is clipped to max(0, keep_rate).
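
A minimal sketch of this keep-rate schedule, assuming a 12-layer encoder; the helper name is ours, and the exact interpolation in the codebase may differ slightly:

def layer_keep_rates(token_keep_rate, num_layers=12, fixed_layers=3):
    # First three layers keep every token, the last layer keeps token_keep_rate,
    # intermediate layers interpolate linearly, and negative values clip to 0.
    rates = []
    for l in range(1, num_layers + 1):
        if l <= fixed_layers:
            rate = 1.0
        else:
            frac = (l - fixed_layers) / (num_layers - fixed_layers)
            rate = 1.0 + frac * (token_keep_rate - 1.0)
        rates.append(max(0.0, rate))
    return rates

print(layer_keep_rates(0.2))  # 1.0 for layers 1-3, decreasing to 0.2 at layer 12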

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5
  • We used {LR} ∈ {0.5, 1, 2}e-5 in the paper.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. By default, evaluation runs once per epoch.

The final model will be checkpointed in {CKPT}/topk/lr_{LR}.

Manual (Non-learnable) Threshold Pruning

Add the following lines in {CKPT}/config.json.

"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01, 

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5 --save_step 500
  • We used {LR} ∈ {0.5, 1, 2}e-5 in the paper.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. By default, evaluation runs once per epoch.
  • Note that the only difference from the learned token pruning mode is that we run the hard threshold mode from the beginning.

The final model will be checkpointed in {CKPT}/hard/lr_{LR}.

Copyright

THIS SOFTWARE AND/OR DATA WAS DEPOSITED IN THE BAIR OPEN RESEARCH COMMONS REPOSITORY ON 02/07/23.


ltp's Issues

Why is masking not applied during testing?

    if self.training and not self.hard_masking:
        if pruner_outputs is not None:
            # learnable threshold and per-token importance scores from the pruner
            threshold, pruning_scores = pruner_outputs['threshold'], pruner_outputs['scores']
            # differentiable soft mask in [0, 1], applied only during soft-mode training
            self.mask = torch.sigmoid((pruning_scores - threshold) / self.temperature)
            layer_output = layer_output * self.mask.unsqueeze(-1)

Does attention mask reduce computation cost?

Hey there.
After reading the code, I am confused about how the computation cost can be reduced by masking more tokens. Did I miss anything?

P.S. I see that the FLOPs are calculated from the number of tokens retained at each layer, which is counted during inference.

Do you have inference latency metrics?

Inference about hard pruning

During hard-pruning inference, tokens below the threshold should be discarded and not enter the computation of the feed-forward layer. However, after normalization and other operations before the feed-forward layer, the positions of the pruned tokens are not equal to 0, so they are still computed, and the same happens when computing the Q and K matrices in the next layer. So where does the accelerated inference actually come from?

FLOPs

Since this is a dynamic transformer, the GFLOPs differ for each input instance. How do you calculate the FLOPs of the entire model? Do you take the average FLOPs over the whole validation set?
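
One plausible way to report this, assuming the per-layer retained token counts are logged during inference (as the previous issue mentions): compute an approximate per-example cost and average it over the validation set. The constants below follow a BERT-base-like layer and are only an illustration, not the repository's exact FLOPs counter.

def example_flops(tokens_per_layer, hidden=768, ffn=3072):
    # Rough per-example cost from the number of tokens retained at each layer:
    # projections and the FFN scale linearly with tokens, attention scores quadratically.
    flops = 0
    for n in tokens_per_layer:
        flops += 8 * n * hidden * hidden  # Q/K/V and output projections
        flops += 4 * n * n * hidden       # attention scores + weighted sum of values
        flops += 4 * n * hidden * ffn     # two feed-forward matmuls
    return flops

def average_flops(all_token_counts):
    # Each input prunes a different number of tokens, so report the mean over examples.
    return sum(example_flops(t) for t in all_token_counts) / len(all_token_counts)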

Question about the max seq length

🖥 Benchmarking transformers

Hi there,

When I run one of the examples in the text classification folder and pass max_seq_length=1024 to the model, I get the following warning: WARNING - main - The max_seq_length passed (1024) is larger than the maximum length for the model (512). Using max_seq_length=512.

Set-up

I'm running on a GPU node with the following command.

python ./examples/text-classification/run_glue.py \
  --model_name_or_path bert-base-cased \
  --task_name mrpc \
  --do_train \
  --do_eval \
  --max_seq_length 1024 \
  --per_device_train_batch_size 8 \
  --learning_rate 2e-5 \
  --num_train_epochs 1 \
  --overwrite_output_dir \
  --output_dir /tmp/mrpc/

It still produces an output, but instead of using a max_seq_length of 1024, it uses max_seq_length=512.

I'm wondering whether this is because the model is still limited to a 512-token maximum length, as with most transformer and BERT-based models, or whether it is caused by the default configuration used in pre-training. In the paper, the authors mention two settings, one of which is 1024, so how can I get the pretrained model with max_seq_length=1024? Thanks!

Cannot run with or without installing transformers

Hi, I found I cannot run the learned token pruning whether or not I install transformers.

If I install transformers, it raises ValueError: Some specified arguments are not used by the HfArgumentParser: ['--masking_mode', 'soft', '--weight_decay_threshold', '0.0', '--lambda_threshold', '0.1', '--temperature', '0.0001']. I think this is because these arguments are customized in src/transformers.

If I use an environment without transformers installed, I get ModuleNotFoundError: No module named 'transformers'. It simply cannot import from /src.

If I force run_glue_ltp to use sys.path.insert(1, './src/'), I get pkg_resources.DistributionNotFound: The 'sacremoses' distribution was not found and is required by this application.

Can you tell me how to run the code properly? Thanks.

No mask used in evaluation process

Can you show how you evaluate the model performance with 'attention_mask'?

According to this line:
https://github.com/kssteven418/LTP/blob/8ab31a623fb71c5f4f8208e878097f214484e848/src/transformers/models/ltp/modeling_ltp.py#L305C27-L305C27

the 'attention_mask' is never used outside the for loop.

So I think you did not use the attention mask in the evaluation part, even though you must need this mask for those labels.

Can you show me your evaluation process? (with some token pruned)

Some specified arguments are not used by the HfArgumentParser

When I run python run.py --arch ltp-base --task SST2 --restore pretrained/bert-base-uncased-SST-2 --lr 2e-5 --temperature 2e-3 --lambda_threshold 0.1 --weight_decay 0 --bs 64 --masking_mode soft --epoch 10 --save_step 100 --no_load, I get:

Some specified arguments are not used by the HfArgumentParser.

I found that the arguments in parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) do not include masking_mode, lr_threshold, weight_decay_threshold, and some others.
