
s-chh / pytorch-vision-transformer-vit-mnist-cifar10

Stars: 62, Watchers: 2, Forks: 10, Size: 989 KB

Simplified PyTorch implementation of the Vision Transformer (ViT) for small datasets such as MNIST, FashionMNIST, SVHN and CIFAR10.

Languages: Python 98.86%, Shell 1.14%
Topics: vision-transformer vit transformer vit-mnist transformer-mnist pytorch-vit scratch simple vit-scratch vit-fashionmnist vit-svhn transformer-cifar10 vit-cifar vit-cifar10 vit-simple

pytorch-vision-transformer-vit-mnist-cifar10's Introduction

Vision Transformer for MNIST and CIFAR10

A simplified, from-scratch PyTorch implementation of the Vision Transformer (ViT), with each step explained in detail (see model.py).

  • Scaled-down version of the original ViT architecture from "An Image is Worth 16x16 Words".
  • Has only about 210k-820k parameters, depending on the embedding dimension (the original ViT-Base has about 86 million).
  • Uses a patch size of 4 (instead of the original 16) so that small, low-resolution images still produce enough patches to work with.
  • Supported datasets: MNIST, FashionMNIST, SVHN, and CIFAR10.
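The patching step behind the patch-size-4 choice can be sketched in plain Python: a 28 x 28 image cut into non-overlapping 4 x 4 patches gives a 7 x 7 grid, i.e. 49 tokens of 16 pixels each. This is a framework-free illustration under our own assumptions (a single-channel image stored as nested lists, patches taken in row-major order); the helper name patchify is hypothetical and this is not the code in model.py, which would perform the equivalent step with tensor operations.

```python
def patchify(image, patch_size):
    """Split an H x W single-channel image (a list of rows) into flattened,
    non-overlapping patch_size x patch_size patches in row-major order."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten one patch row by row into a vector of patch_size**2 values.
            patch = [image[top + r][left + c]
                     for r in range(patch_size)
                     for c in range(patch_size)]
            patches.append(patch)
    return patches


# A dummy 28 x 28 "MNIST-sized" image whose pixel value encodes its position.
image = [[row * 28 + col for col in range(28)] for row in range(28)]
patches = patchify(image, patch_size=4)
print(len(patches), len(patches[0]))  # 49 16
```

In a ViT, each flattened patch is then linearly projected to the embedding dimension before the transformer layers; with 3 x 32 x 32 inputs the same split yields an 8 x 8 grid of 64 patches.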



Run commands (also available in scripts.sh):

| Dataset | Run command | Test Acc (%) |
|---|---|---|
| MNIST | python main.py --dataset mnist --epochs 100 | 99.5 |
| Fashion MNIST | python main.py --dataset fmnist | 92.3 |
| SVHN | python main.py --dataset svhn --n_channels 3 --image_size 32 --embed_dim 128 | 96.2 |
| CIFAR10 | python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128 | 82.5 (86.3 with RandAug) |



Transformer Config:

| Config | MNIST and FMNIST | SVHN and CIFAR10 |
|---|---|---|
| Input Size (C x H x W) | 1 x 28 x 28 | 3 x 32 x 32 |
| Patch Size | 4 | 4 |
| Sequence Length | 7 x 7 = 49 | 8 x 8 = 64 |
| Embedding Size | 64 | 128 |
| Parameters | 210k | 820k |
| Num of Layers | 6 | 6 |
| Num of Heads | 4 | 4 |
| Forward Multiplier | 2 | 2 |
| Dropout | 0.1 | 0.1 |
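As a rough cross-check of the Sequence Length and Parameters rows, the sketch below counts the weights of a generic pre-norm ViT encoder. Every detail of the layer layout is our assumption, not read from model.py: a patch projection with bias (a linear layer or an equivalent 4 x 4 convolution), a learnable class token, learnable position embeddings for the class token plus every patch, Q/K/V and output projections with biases, an MLP of width Forward Multiplier x Embedding Size, two LayerNorms per layer plus a final one, and a single linear classifier over 10 classes. The function name vit_param_count is hypothetical.

```python
def vit_param_count(in_channels, image_size, patch_size, embed_dim,
                    num_layers, forward_mul, num_classes=10):
    """Return (sequence length, parameter count) for an assumed standard ViT layout."""
    num_patches = (image_size // patch_size) ** 2
    patch_embed = in_channels * patch_size ** 2 * embed_dim + embed_dim  # weights + bias
    cls_token = embed_dim
    pos_embed = (num_patches + 1) * embed_dim  # class token + one slot per patch
    layer_norm = 2 * embed_dim  # scale + shift
    attention = 4 * (embed_dim * embed_dim + embed_dim)  # Q, K, V and output projections
    hidden = forward_mul * embed_dim
    mlp = (embed_dim * hidden + hidden) + (hidden * embed_dim + embed_dim)
    per_layer = 2 * layer_norm + attention + mlp
    head = layer_norm + embed_dim * num_classes + num_classes  # final norm + classifier
    total = patch_embed + cls_token + pos_embed + num_layers * per_layer + head
    return num_patches, total


print(vit_param_count(1, 28, 4, 64, 6, 2))   # (49, 205962)  MNIST / FMNIST column
print(vit_param_count(3, 32, 4, 128, 6, 2))  # (64, 811146)  SVHN / CIFAR10 column
```

Under these assumptions the counts land within about 2% of the 210k and 820k in the table; the remaining gap presumably comes from details such as the exact classifier head, which this sketch does not know. The number of heads does not appear at all, because multi-head attention splits the embedding dimension across heads rather than adding weights.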

pytorch-vision-transformer-vit-mnist-cifar10's People

Contributors

s-chh


pytorch-vision-transformer-vit-mnist-cifar10's Issues

Question about running

Thanks for your great work!

I want to run your code, but there is an error.

Traceback (most recent call last):
  File "main.py", line 50, in <module>
    main(args)
  File "main.py", line 10, in main
    solver = Solver(args)
  File "/home/ikenaga/xiexiaomeng/PyTorch-ViT-MNIST-main/solver.py", line 15, in __init__
    self.model = VisionTransformer(args).cuda()
  File "/home/ikenaga/xiexiaomeng/PyTorch-ViT-MNIST-main/model.py", line 112, in __init__
    super(Transformer, self).__init__()
NameError: name 'Transformer' is not defined

Could you please share your software environment?
My software environment is as follows:
torch==1.8.1
Python==3.7
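The traceback points at a super() call that names a class, Transformer, which is not defined in model.py at that point. A likely cause is that the class was renamed to VisionTransformer while the old name stayed in the super(...) call; this is our inference from the traceback, not something confirmed by the repository. The sketch below reproduces the error with a stand-in base class (Module, a hypothetical placeholder for torch.nn.Module so the example runs without PyTorch) and shows the zero-argument super().__init__() form, available in Python 3 and therefore on Python 3.7, which avoids hard-coding the class name.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module so the example runs without PyTorch."""

    def __init__(self):
        self.initialized = True


class BrokenVisionTransformer(Module):
    def __init__(self):
        # Refers to a class name that does not exist -> NameError when instantiated.
        super(Transformer, self).__init__()  # noqa: F821


class VisionTransformer(Module):
    def __init__(self):
        # Zero-argument super() picks up the enclosing class automatically.
        super().__init__()


try:
    BrokenVisionTransformer()
except NameError as err:
    print(err)  # name 'Transformer' is not defined

print(VisionTransformer().initialized)  # True
```

Keeping the two-argument form but naming the correct class, super(VisionTransformer, self).__init__(), would work as well.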

Rotary Position Embedding

I really enjoy reading your code; it is very clear and easy to understand, while still achieving top results on the benchmarks! I would like to explore Rotary Position Embedding (used by LLaMA models) using your example. Have you already tried it? Do you have any experience using it in a ViT?
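For readers following this question, here is a framework-free sketch of the standard 1D rotary position embedding (RoPE, introduced with RoFormer): each consecutive pair of query or key features is rotated by an angle equal to the token position times base ** (-2i / d), where i is the pair index and d the feature dimension. Applying it to the flattened sequence of patch tokens is our assumption and ignores the 2D layout of the patch grid; the function name apply_rope is hypothetical, and base = 10000 follows the common convention rather than anything in this repository. The property that matters is that the score between a rotated query at position m and a rotated key at position n depends only on the offset between them.

```python
import math


def apply_rope(vector, position, base=10000.0):
    """Rotate consecutive feature pairs (x[2i], x[2i+1]) of one query or key
    vector by position * base ** (-2i / d): the standard 1D rotary embedding."""
    dim = len(vector)
    if dim % 2:
        raise ValueError("feature dimension must be even")
    rotated = []
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)  # i already equals 2 * pair index
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        x, y = vector[i], vector[i + 1]
        rotated.extend([x * cos_t - y * sin_t, x * sin_t + y * cos_t])
    return rotated


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Toy 8-dimensional query/key, e.g. one attention head's slice of a patch token.
q = [0.1, -0.3, 0.5, 0.2, -0.4, 0.7, 0.05, 0.9]
k = [0.6, 0.1, -0.2, 0.8, 0.3, -0.5, 0.4, 0.2]

# The same offset (3) at different absolute positions gives the same score.
s1 = dot(apply_rope(q, 5), apply_rope(k, 2))
s2 = dot(apply_rope(q, 40), apply_rope(k, 37))
print(abs(s1 - s2) < 1e-9)  # True
```

In a ViT the rotation would be applied to Q and K inside each attention head before the dot product (V is left unrotated), usually replacing the additive learned position embedding.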
