VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale

Official PyTorch implementation of VanillaKD, from the following paper:
VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale
Zhiwei Hao, Jianyuan Guo, Kai Han, Han Hu, Chang Xu, Yunhe Wang

This paper emphasizes the importance of scale in achieving superior results. It reveals that previous KD methods, designed solely on small-scale datasets, have underestimated the effectiveness of vanilla KD on large-scale datasets, a phenomenon referred to as the small data pitfall. By incorporating stronger data augmentation and larger datasets, the performance gap between vanilla KD and other approaches is narrowed:

Without bells and whistles, state-of-the-art results are achieved for ResNet-50, ViT-S, and ConvNeXtV2-T models on ImageNet, showing that vanilla KD is elegantly simple yet astonishingly effective in large-scale scenarios.
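
For reference, the vanilla KD objective is simply the KL divergence between temperature-softened teacher and student logits, optionally combined with the ground-truth classification loss. Below is a minimal PyTorch sketch; the default temperature is illustrative, not the exact setting used in the paper:

import torch.nn.functional as F

def vanilla_kd_loss(student_logits, teacher_logits, temperature=1.0):
    # Soften both distributions with the temperature, then take KL(teacher || student).
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2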

If you find this project useful in your research, please cite:

@article{hao2023vanillakd,
  title={VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale},
  author={Hao, Zhiwei and Guo, Jianyuan and Han, Kai and Hu, Han and Xu, Chang and Wang, Yunhe},
  journal={arXiv preprint arXiv:2305.15781},
  year={2023}
}

Model Zoo

We provide models trained by vanilla KD on ImageNet.

name                   acc@1  acc@5  model
resnet50               83.08  96.35  model
vit_tiny_patch16_224   78.11  94.26  model
vit_small_patch16_224  84.33  97.09  model
convnextv2_tiny        85.03  97.44  model
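
A downloaded checkpoint can be loaded through timm for inference. A minimal sketch, assuming the file is a standard timm-compatible checkpoint; the local path below is a placeholder:

import torch
import timm

ckpt_path = "/path/to/downloaded_checkpoint.pth"  # placeholder path

# timm resolves common checkpoint layouts (e.g. a top-level 'state_dict' key) itself.
model = timm.create_model("resnet50", num_classes=1000, checkpoint_path=ckpt_path)
model.eval()

# Quick shape check with a dummy ImageNet-sized input.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expect torch.Size([1, 1000])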

Usage

First, clone the repository locally:

git clone https://github.com/Hao840/vanillaKD.git

Then, install PyTorch and timm 0.6.5:

conda install -c pytorch pytorch torchvision
pip install timm==0.6.5

Our results are produced with torch==1.10.2+cu113 torchvision==0.11.3+cu113 timm==0.6.5. Other versions might also work.

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is:

│path/to/imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
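
To sanity-check the layout before training, the folders can be read directly with torchvision's ImageFolder, which expects exactly the class-subfolder structure shown above. A minimal sketch; the dataset root is a placeholder:

from torchvision import datasets, transforms

root = "/path/to/imagenet"  # placeholder path

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
train_set = datasets.ImageFolder(f"{root}/train", transform=transform)
val_set = datasets.ImageFolder(f"{root}/val", transform=transform)
print(len(train_set.classes), len(train_set), len(val_set))  # expect 1000 classes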

Evaluation

To evaluate a distilled model on ImageNet val with a single GPU, run:

python validate.py /path/to/imagenet --model <model name> --checkpoint /path/to/checkpoint

Training

To train a ResNet50 student using a BEiTv2-B teacher on ImageNet on a single node with 8 GPUs, run one of the following:

Strategy A2:

python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss kd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1

Strategy A1:

python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss kd --amp --epochs 600 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.01 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.1 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.2 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1

Commands for reproducing baseline results:

DKD Training with ResNet50 student, BEiTv2-B teacher, and strategy A2 for 300 epochs
python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss dkd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1
DIST Training with ResNet50 student, BEiTv2-B teacher, and strategy A2 for 300 epochs
python -m torch.distributed.launch --nproc_per_node=8 train-kd.py /path/to/imagenet --model resnet50 --teacher beitv2_base_patch16_224 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss dist --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 1
Correlation Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs
python -m torch.distributed.launch --nproc_per_node=8 train-fd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss correlation --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0
RKD Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs
python -m torch.distributed.launch --nproc_per_node=8 train-fd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss rkd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0
ReviewKD Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs
python -m torch.distributed.launch --nproc_per_node=8 train-fd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss review --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0
CRD Training with ResNet50 student, ResNet152 teacher, and strategy A2 for 300 epochs
python -m torch.distributed.launch --nproc_per_node=8 train-crd.py /path/to/imagenet --model resnet50 --teacher resnet152 --teacher-pretrained /path/to/teacher_checkpoint --kd-loss crd --amp --epochs 300 --batch-size 256 --lr 5e-3 --opt lamb --sched cosine --weight-decay 0.02 --warmup-epochs 5 --warmup-lr 1e-6 --smoothing 0.0 --drop 0 --drop-path 0.05 --aug-repeats 3 --aa rand-m7-mstd0.5 --mixup 0.1 --cutmix 1.0 --color-jitter 0 --crop-pct 0.95 --bce-loss 0

Acknowledgement

This repository is built using the timm library and the DKD, DIST, DeiT, BEiT v2, and ConvNeXt v2 repositories.
