MaxViT: Multi-Axis Vision Transformer

License: MIT

Unofficial PyTorch reimplementation of the paper MaxViT: Multi-Axis Vision Transformer by Zhengzhong Tu et al. (Google Research).

[Figure taken from the paper.]

Note: timm offers MaxViT weights pre-trained on ImageNet!

Installation

You can simply install the MaxViT implementation as a Python package by using pip.

pip install git+https://github.com/ChristophReich1996/MaxViT

Alternatively, you can clone the repository and use the implementation in maxvit directly in your project.

This implementation relies only on PyTorch and timm (see requirements.txt).

Usage

This implementation provides the pre-configured models from the paper (tiny, small, base, and large, each for 224 × 224 inputs), which can be used as follows:

import torch
import maxvit

# Tiny model
network: maxvit.MaxViT = maxvit.max_vit_tiny_224(num_classes=1000)
input = torch.rand(1, 3, 224, 224)
output = network(input)

# Small model
network: maxvit.MaxViT = maxvit.max_vit_small_224(num_classes=365, in_channels=1)
input = torch.rand(1, 1, 224, 224)
output = network(input)

# Base model
network: maxvit.MaxViT = maxvit.max_vit_base_224(in_channels=4)
input = torch.rand(1, 4, 224, 224)
output = network(input)

# Large model
network: maxvit.MaxViT = maxvit.max_vit_large_224()
input = torch.rand(1, 3, 224, 224)
output = network(input)
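The outputs above can be turned into class predictions. The following is a minimal sketch, assuming the network returns raw logits of shape (batch, num_classes). The helper `predict_top_k` and the small stand-in network are hypothetical and not part of the package; the stand-in only keeps the example runnable without maxvit installed.

```python
from typing import Tuple

import torch
import torch.nn as nn


@torch.no_grad()
def predict_top_k(network: nn.Module, images: torch.Tensor, k: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the top-k class probabilities and class indices per image.

    Assumption: `network` maps images to raw logits of shape (batch, num_classes).
    """
    network.eval()  # switch dropout and batch norm to inference behavior
    logits = network(images)
    probabilities = logits.softmax(dim=-1)
    top_probabilities, top_classes = probabilities.topk(k, dim=-1)
    return top_probabilities, top_classes


# Hypothetical stand-in classifier (10 classes, 8 x 8 inputs), used here instead of
# a real model such as maxvit.max_vit_tiny_224(num_classes=1000).
stand_in = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
top_probabilities, top_classes = predict_top_k(stand_in, torch.rand(2, 3, 8, 8), k=3)
```

With a real model, the same call would take a batch of shape (batch, in_channels, 224, 224). Note that the helper leaves the network in evaluation mode.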

To access the names of the network parameters for which weight decay is not recommended, call nwd: Set[str] = network.no_weight_decay().
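One way to use this set is to build two optimizer parameter groups, one with and one without weight decay. The following is a minimal sketch, assuming the returned names match the keys of `named_parameters()`. The helper `build_param_groups`, the `TinyStandIn` module, and the hyperparameter values are hypothetical; the stand-in only mimics the `no_weight_decay()` interface so the example runs without maxvit installed.

```python
from typing import Dict, List, Set

import torch
import torch.nn as nn


def build_param_groups(network: nn.Module, no_decay: Set[str], weight_decay: float) -> List[Dict]:
    """Split trainable parameters into a decayed and a non-decayed group."""
    decay_params, no_decay_params = [], []
    for name, param in network.named_parameters():
        if not param.requires_grad:
            continue
        # Assumption: names in `no_decay` are full named_parameters() keys
        if name in no_decay:
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]


class TinyStandIn(nn.Module):
    """Hypothetical stand-in for maxvit.MaxViT (not the real architecture)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(9, 1))

    def no_weight_decay(self) -> Set[str]:
        return {"relative_position_bias_table"}


network = TinyStandIn()
param_groups = build_param_groups(network, network.no_weight_decay(), weight_decay=0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```

Passing explicit groups overrides the optimizer-wide `weight_decay` default for the second group only, so all other parameters are still regularized.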

If you want to use a custom configuration, you can use the MaxViT class directly. Its constructor takes the following parameters:

| Parameter | Description | Type |
| --- | --- | --- |
| in_channels | Number of input channels to the convolutional stem. Default: 3 | int, optional |
| depths | Depth of each network stage. Default: (2, 2, 5, 2) | Tuple[int, ...], optional |
| channels | Number of channels in each network stage. Default: (64, 128, 256, 512) | Tuple[int, ...], optional |
| num_classes | Number of classes to be predicted. Default: 1000 | int, optional |
| embed_dim | Embedding dimension of the convolutional stem. Default: 64 | int, optional |
| num_heads | Number of attention heads. Default: 32 | int, optional |
| grid_window_size | Grid/window size to be utilized. Default: (7, 7) | Tuple[int, int], optional |
| attn_drop | Dropout ratio of the attention weights. Default: 0.0 | float, optional |
| drop | Dropout ratio of the output. Default: 0.0 | float, optional |
| drop_path | Dropout ratio of the path (stochastic depth). Default: 0.0 | float, optional |
| mlp_ratio | Ratio of MLP hidden dimension to embedding dimension. Default: 4.0 | float, optional |
| act_layer | Type of activation layer to be utilized. Default: nn.GELU | Type[nn.Module], optional |
| norm_layer | Type of normalization layer to be utilized. Default: nn.BatchNorm2d | Type[nn.Module], optional |
| norm_layer_transformer | Normalization layer in the Transformer. Default: nn.LayerNorm | Type[nn.Module], optional |
| global_pool | Global pooling type to be utilized. Default: "avg" | str, optional |
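When choosing custom values for depths, channels, and grid_window_size, it can help to check that they fit the input resolution before building the model. The following is a rough sketch under two assumptions that are not stated in this README: the convolutional stem and every stage each halve the spatial resolution (as described in the paper), and each stage's feature map must be divisible by the grid/window size. The helpers `stage_resolutions` and `check_config` are hypothetical and not part of the package.

```python
from typing import List, Tuple


def stage_resolutions(image_size: int, num_stages: int) -> List[int]:
    """Side length of the feature map after each stage.

    Assumption: the stem and every stage each downsample by a factor of 2.
    """
    resolution = image_size // 2  # after the convolutional stem
    resolutions = []
    for _ in range(num_stages):
        resolution //= 2
        resolutions.append(resolution)
    return resolutions


def check_config(
    depths: Tuple[int, ...],
    channels: Tuple[int, ...],
    grid_window_size: Tuple[int, int],
    image_size: int = 224,
) -> None:
    """Raise ValueError if the configuration looks inconsistent."""
    if len(depths) != len(channels):
        raise ValueError("depths and channels must have one entry per stage")
    for stage, resolution in enumerate(stage_resolutions(image_size, len(depths))):
        for size in grid_window_size:
            if resolution == 0 or resolution % size != 0:
                raise ValueError(
                    f"stage {stage}: feature map {resolution} is not divisible by window {size}"
                )


# Default values from the parameter table above
config = dict(depths=(2, 2, 5, 2), channels=(64, 128, 256, 512), grid_window_size=(7, 7))
check_config(**config, image_size=224)
resolutions = stage_resolutions(224, len(config["depths"]))
```

Under these assumptions a 224 × 224 input yields feature maps of side 56, 28, 14, and 7, all multiples of the default 7 × 7 window. Since the dictionary keys follow the parameter names in the table, a validated configuration could then be passed on as maxvit.MaxViT(**config).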

Disclaimer

This is a highly experimental implementation based solely on the MaxViT paper. Since an official implementation of MaxViT has not yet been published, it is not possible to say to what extent this implementation differs from the original. If you run into problems with this implementation, please open an issue.

Reference

@article{Tu2022,
    title={{MaxViT: Multi-Axis Vision Transformer}},
    author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan
            and Li, Yinxiao},
    journal={arXiv preprint arXiv:2204.01697},
    year={2022}
}
