Vision Transformer - Pytorch

PyTorch implementation of Vision Transformer. Pretrained PyTorch weights, converted from the original jax/flax weights, are provided. This is a project of the ASYML family and CASL.

Introduction

Figure 1 from paper

PyTorch implementation of the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. We provide pretrained PyTorch weights converted from the pretrained jax/flax models, along with fine-tuning and evaluation scripts. The results are close to those of the original implementation.
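As a rough illustration of the "16x16 words" in the paper's title, the sketch below counts how many tokens the Transformer encoder sees for a given image and patch size. The helper `num_tokens` is hypothetical and not part of this repository; it assumes square images, non-overlapping square patches, and one extra class token, as described in the paper.

```python
def num_tokens(image_size: int, patch_size: int, class_token: bool = True) -> int:
    """Sequence length seen by a ViT encoder for a square image (illustrative helper)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    n_patches = patches_per_side ** 2
    # One extra token is prepended for classification when class_token is True.
    return n_patches + 1 if class_token else n_patches


if __name__ == "__main__":
    # 384 is the image size used in the commands of this README;
    # 16 and 32 are the patch sizes of the ViT-*_16 and ViT-*_32 models.
    for patch in (16, 32):
        print(f"image 384, patch {patch}: {num_tokens(384, patch)} tokens")
```

With an image size of 384, a patch size of 16 gives 24 x 24 = 576 patches plus the class token, and a patch size of 32 gives 12 x 12 = 144 patches plus the class token.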

Installation

Create environment:

conda create --name vit --file requirements.txt
conda activate vit
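After activating the environment, a quick check along the following lines can confirm that the main packages are importable. This is only a sketch: the helper `check_packages` is hypothetical, and the package names `torch` and `numpy` are assumptions about what requirements.txt contains.

```python
import importlib.util


def check_packages(names):
    """Return {name: bool} telling whether each top-level package can be found."""
    # find_spec returns None for a missing top-level package; it does not import it.
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed package names; adjust to match requirements.txt.
    for name, found in check_packages(["torch", "numpy"]).items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```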

Available Models

We provide PyTorch model weights, which are converted from the original jax/flax weights. You can download them and put the files under 'weights/pytorch' to use them.

Alternatively, you can download the original jax/flax weights and put the files under 'weights/jax' to use them. We will convert them for you on the fly.
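The directory convention above can be sketched as a small lookup that prefers an already converted PyTorch file and falls back to the original jax/flax file. The function `find_checkpoint` is a hypothetical helper, not part of the repository; the '.pth' and '.npy' extensions follow the example paths used in this README.

```python
from pathlib import Path


def find_checkpoint(name, root="weights"):
    """Return the first existing weight file for `name`, or None if none is present."""
    root = Path(root)
    candidates = [
        root / "pytorch" / f"{name}.pth",  # converted PyTorch weights
        root / "jax" / f"{name}.npy",      # original jax/flax weights
    ]
    for path in candidates:
        if path.is_file():
            return path
    return None


if __name__ == "__main__":
    path = find_checkpoint("imagenet21k+imagenet2012_ViT-B_16")
    print(path if path is not None else "no weights found, download them first")
```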

Datasets

Currently three datasets are supported: ImageNet2012, CIFAR10, and CIFAR100. To evaluate or fine-tune on these datasets, download the datasets and put them in 'data/dataset_name'.

More datasets will be supported.
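A minimal sketch of the 'data/dataset_name' convention, together with the class count each dataset needs for `--num-classes`, might look like this. The names `NUM_CLASSES` and `dataset_dir` are illustrative only, and the exact folder and dataset names the scripts expect may differ (the evaluation example in this README uses 'data/ImageNet').

```python
from pathlib import Path

# Number of classes for each of the three supported datasets.
NUM_CLASSES = {"ImageNet2012": 1000, "CIFAR10": 10, "CIFAR100": 100}


def dataset_dir(name, data_root="data"):
    """Return the directory where a supported dataset is expected to live."""
    if name not in NUM_CLASSES:
        raise ValueError(f"unsupported dataset: {name!r}")
    return Path(data_root) / name


if __name__ == "__main__":
    for name, classes in NUM_CLASSES.items():
        print(f"{name}: {classes} classes, expected under {dataset_dir(name)}")
```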

Fine-Tune/Train

python src/train.py --exp-name ft --n-gpu 4 --tensorboard --model-arch b16 --checkpoint-path weights/pytorch/imagenet21k+imagenet2012_ViT-B_16.pth --image-size 384 --batch-size 32 --data-dir data/ --dataset CIFAR10 --num-classes 10 --train-steps 10000 --lr 0.03 --wd 0.0
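To run the same fine-tuning on several datasets, the command above can be assembled from Python. The sketch below only reuses flags that appear in the command above; `build_train_command` is a hypothetical helper, and passing `CIFAR100` with 100 classes is an assumption based on the list of supported datasets.

```python
import shlex
import sys


def build_train_command(dataset, num_classes, checkpoint_path, exp_name="ft"):
    """Build the argument list for src/train.py with the settings shown above."""
    return [
        sys.executable, "src/train.py",
        "--exp-name", exp_name,
        "--n-gpu", "4",
        "--tensorboard",
        "--model-arch", "b16",
        "--checkpoint-path", checkpoint_path,
        "--image-size", "384",
        "--batch-size", "32",
        "--data-dir", "data/",
        "--dataset", dataset,
        "--num-classes", str(num_classes),
        "--train-steps", "10000",
        "--lr", "0.03",
        "--wd", "0.0",
    ]


cmd = build_train_command(
    "CIFAR100", 100, "weights/pytorch/imagenet21k+imagenet2012_ViT-B_16.pth"
)
print(shlex.join(cmd))
# To actually launch training, pass the list to subprocess.run(cmd, check=True).
```

Building the command as a list rather than a single string avoids shell quoting issues with characters such as '+' in the checkpoint file name.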

Evaluation

Make sure you have downloaded the pretrained weights in either '.npy' or '.pth' format.

python src/eval.py --model-arch b16 --checkpoint-path weights/jax/imagenet21k+imagenet2012_ViT-B_16.npy --image-size 384 --batch-size 128 --data-dir data/ImageNet --dataset ImageNet --num-classes 1000
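For readers unfamiliar with the metric, the following sketch shows how a top-1 accuracy percentage can be computed from per-class scores. It assumes the accuracies reported in this README are top-1 percentages; `top1_accuracy` is an illustrative pure-Python function and not the code used by src/eval.py.

```python
def top1_accuracy(logits, labels):
    """Percentage of samples whose highest-scoring class equals the true label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("at least one sample is required")
    correct = 0
    for scores, label in zip(logits, labels):
        # Index of the largest score is the predicted class.
        predicted = max(range(len(scores)), key=scores.__getitem__)
        if predicted == label:
            correct += 1
    return 100.0 * correct / len(labels)


# Toy example: four samples, three classes; the last prediction is wrong.
logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.4, 0.9], [0.3, 0.2, 0.1]]
labels = [1, 0, 2, 1]
print(top1_accuracy(logits, labels))  # 3 of 4 correct -> 75.0
```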

Results and Models

Pretrained Results on ImageNet2012

| upstream | model | dataset | orig. jax acc (%) | pytorch acc (%) | model link |
|---|---|---|---|---|---|
| imagenet21k | ViT-B_16 | imagenet2012 | 84.62 | 83.90 | checkpoint |
| imagenet21k | ViT-B_32 | imagenet2012 | 81.79 | 81.14 | checkpoint |
| imagenet21k | ViT-L_16 | imagenet2012 | 85.07 | 84.94 | checkpoint |
| imagenet21k | ViT-L_32 | imagenet2012 | 82.01 | 81.03 | checkpoint |
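To make the comparison with the original implementation concrete, the gap between the two accuracy columns can be computed directly from the numbers above. This is plain arithmetic on the reported values; the variable names are illustrative.

```python
# (orig. jax acc, pytorch acc) on imagenet2012, copied from the table above.
results = {
    "ViT-B_16": (84.62, 83.90),
    "ViT-B_32": (81.79, 81.14),
    "ViT-L_16": (85.07, 84.94),
    "ViT-L_32": (82.01, 81.03),
}

# Gap in percentage points, rounded to two decimals to hide float noise.
gaps = {model: round(jax_acc - pt_acc, 2) for model, (jax_acc, pt_acc) in results.items()}

for model, gap in gaps.items():
    print(f"{model}: {gap:.2f} points below the original jax accuracy")
```

The gaps range from 0.13 points (ViT-L_16) to 0.98 points (ViT-L_32).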

Fine-Tune Results on CIFAR10/100

Due to limited GPU resources, the fine-tune results were obtained with a batch size of 32, which may slightly affect performance.

| upstream | model | dataset | orig. jax acc (%) | pytorch acc (%) |
|---|---|---|---|---|
| imagenet21k | ViT-B_16 | CIFAR10 | 98.92 | 98.90 |
| imagenet21k | ViT-B_16 | CIFAR100 | 92.26 | 91.65 |

TODO

  • Colab
  • Integrated into Texar

Acknowledgements

  1. https://github.com/google-research/vision_transformer
  2. https://github.com/lucidrains/vit-pytorch
  3. https://github.com/kamalkraj/Vision-Transformer

Contributing

Issues and pull requests are welcome for improving this repo. Please follow the contribution guide.

License

Apache License 2.0
