
[ECCV 2022] Official repository for "MaxViT: Multi-Axis Vision Transformer". SOTA foundation models for classification, detection, segmentation, image quality, and generative modeling...

License: Apache License 2.0



MaxViT: Multi-Axis Vision Transformer (ECCV 2022)

Paper | Colab Tutorial | Video

This repository hosts the official TensorFlow implementation of the MaxViT models:

MaxViT: Multi-Axis Vision Transformer. ECCV 2022.
Zhengzhong Tu, Hossein Talebi, Han Zhang, Feng Yang, Peyman Milanfar, Alan Bovik, and Yinxiao Li
Google Research, University of Texas at Austin

Disclaimer: This is not an officially supported Google product.

News:

  • May 2023: MaxViT is officially released in the TensorFlow Model Garden with training support!
  • Oct 12, 2022: Added the remaining ImageNet-1K and -21K checkpoints.
  • Oct 4, 2022: A list of updates:
    • Added MaxViT-Tiny and MaxViT-Small checkpoints.
    • Added a Colab tutorial.
  • Sep 8, 2022: Our Google AI blog post covering both MaxViT and MAXIM is live.
  • Sep 7, 2022: @rwightman released a few small model weights in timm, achieving even better results than reported in our paper. See more here.
  • Aug 26, 2022: Our MaxViT models have been implemented in timm (pytorch-image-models). Kudos to @rwightman!
  • Jul 21, 2022: Initial code release of the MaxViT models, accepted to ECCV 2022.
  • Apr 6, 2022: MaxViT has been implemented by @lucidrains: vit-pytorch 😱 🤯
  • Apr 4, 2022: Initial upload to arXiv.

MaxViT Models

MaxViT is a family of hybrid (CNN + ViT) image classification models that achieve better parameter and FLOPs efficiency across the board than both SoTA ConvNets and Transformers. They also scale well to large datasets such as ImageNet-21K. Notably, because the grid attention used has linear complexity in the number of tokens, MaxViT can "see" globally throughout the entire network, even in the earlier, high-resolution stages.
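To make the linear-complexity claim concrete, here is a back-of-the-envelope count of attended token pairs for one attention layer, assuming the paper's fixed window/grid size P = G = 7 (an illustrative sketch, not code from this repository):

H = W = 56  # early-stage feature map
P = 7       # fixed window/grid size from the paper

full_attention = (H * W) ** 2                # vanilla ViT: quadratic in H*W
num_windows = (H // P) * (W // P)
block_or_grid = num_windows * (P * P) ** 2   # per-window quadratic, linear in H*W

print(full_attention)   # 9834496
print(block_or_grid)    # 153664

Doubling the spatial area doubles block_or_grid but quadruples full_attention, which is why global (grid) attention stays affordable at high resolution.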

[Figure] MaxViT meta-architecture.

[Figure] Results on ImageNet-1K train and test.

[Figure] Results on ImageNet-21K and JFT pre-trained models.

Colab Demo

We have released a Google Colab demo that walks through running MaxViT on images. Try it here: Open In Colab

Pretrained MaxViT Checkpoints

We provide results and checkpoints below:

Name | Resolution | Top-1 Acc. | #Params | FLOPs | Model
MaxViT-T | 224x224 | 83.62% | 31M | 5.6B | ckpt
MaxViT-T | 384x384 | 85.24% | 31M | 17.7B | ckpt
MaxViT-T | 512x512 | 85.72% | 31M | 33.7B | ckpt
MaxViT-S | 224x224 | 84.45% | 69M | 11.7B | ckpt
MaxViT-S | 384x384 | 85.74% | 69M | 36.1B | ckpt
MaxViT-S | 512x512 | 86.19% | 69M | 67.6B | ckpt
MaxViT-B | 224x224 | 84.95% | 119M | 24.2B | ckpt
MaxViT-B | 384x384 | 86.34% | 119M | 74.2B | ckpt
MaxViT-B | 512x512 | 86.66% | 119M | 138.5B | ckpt
MaxViT-L | 224x224 | 85.17% | 212M | 43.9B | ckpt
MaxViT-L | 384x384 | 86.40% | 212M | 133.1B | ckpt
MaxViT-L | 512x512 | 86.70% | 212M | 245.4B | ckpt

Here is a list of ImageNet-21K pre-trained and ImageNet-1K fine-tuned models:

Name | Resolution | Top-1 Acc. | #Params | FLOPs | 21k model | 1k model
MaxViT-B | 224x224 | - | 119M | 24.2B | ckpt | -
MaxViT-B | 384x384 | - | 119M | 74.2B | - | ckpt
MaxViT-B | 512x512 | - | 119M | 138.5B | - | ckpt
MaxViT-L | 224x224 | - | 212M | 43.9B | ckpt | -
MaxViT-L | 384x384 | - | 212M | 133.1B | - | ckpt
MaxViT-L | 512x512 | - | 212M | 245.4B | - | ckpt
MaxViT-XL | 224x224 | - | 475M | 97.8B | ckpt | -
MaxViT-XL | 384x384 | - | 475M | 293.7B | - | ckpt
MaxViT-XL | 512x512 | - | 475M | 535.2B | - | ckpt

Citation

Should you find this repository useful, please consider citing:

@article{tu2022maxvit,
  title={MaxViT: Multi-Axis Vision Transformer},
  author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
  journal={ECCV},
  year={2022},
}

Other Related Works

  • MAXIM: Multi-Axis MLP for Image Processing, CVPR 2022. Paper | Code
  • CoBEVT: Cooperative Bird's Eye View Semantic Segmentation with Sparse Transformers, CoRL 2022. Paper | Code
  • Improved Transformer for High-Resolution GANs, NeurIPS 2021. Paper | Code
  • CoAtNet: Marrying Convolution and Attention for All Data Sizes, NeurIPS 2021. Paper
  • EfficientNetV2: Smaller Models and Faster Training, ICML 2021. Paper | Code

Acknowledgement: This repository is built on top of the EfficientNets and CoAtNet codebases.


Issues

How to load ImageNet-21K pretrained weights

I tried to load the model.ckpt via model.load_weights() and also followed the tutorial using eval_driver, but was unable to load the weights correctly. It raised: AssertionError: Some objects had attributes which were not restored.
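For reference, the usual TF2 object-based restore pattern looks like the sketch below. It uses a stand-in Dense model because the exact MaxViT constructor is not shown here; substitute the model built via the tutorial, whose variable structure must match the checkpoint. expect_partial() is the standard way to silence the "not restored" assertion when the checkpoint contains extra objects such as optimizer slots:

import tensorflow as tf

# Stand-in model; replace with the MaxViT model built via the tutorial.
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
_ = model(tf.zeros([1, 8]))  # build variables before saving/restoring

path = tf.train.Checkpoint(model=model).save('/tmp/demo_ckpt')

restored_model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
status = tf.train.Checkpoint(model=restored_model).restore(path)
_ = restored_model(tf.zeros([1, 8]))       # forward pass triggers deferred restore
status.expect_partial()                    # tolerate unrestored extras (e.g. optimizer slots)
status.assert_existing_objects_matched()   # but verify all model variables matched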

Add a preprocessing method that takes a raw image tensor

I tried the following preprocessing:

img_ = eval_driver.get_preprocess_fn()(tf.io.read_file(path))

But I constantly get errors like:

{{function_node __wrapped__ExtractJpegShape_device_/job:localhost/replica:0/task:0/device:CPU:0}} Invalid JPEG data, size 442707 [Op:ExtractJpegShape]
Not a JPEG file: starts with 0x89 0x50
{{function_node __wrapped__ExtractJpegShape_device_/job:localhost/replica:0/task:0/device:CPU:0}} Invalid JPEG data, size 328248 [Op:ExtractJpegShape]
{{function_node __wrapped__ExtractJpegShape_device_/job:localhost/replica:0/task:0/device:CPU:0}} Invalid JPEG data, size 156137 [Op:ExtractJpegShape]
Not a JPEG file: starts with 0x89 0x50
Not a JPEG file: starts with 0x89 0x50
{{function_node __wrapped__ExtractJpegShape_device_/job:localhost/replica:0/task:0/device:CPU:0}} Invalid JPEG data, size 258371 [Op:ExtractJpegShape]
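For what it's worth, 0x89 0x50 is the PNG magic number, so these errors suggest PNG files are being fed through a JPEG-only decode path. A minimal workaround sketch, assuming TF2 and simple resize-plus-scale preprocessing (the repo's exact input normalization may differ):

import tensorflow as tf

def preprocess_raw_image(image_bytes, image_size=224):
    # tf.io.decode_image handles JPEG, PNG, GIF, and BMP, avoiding the
    # JPEG-only ExtractJpegShape op that fails on PNG input.
    img = tf.io.decode_image(image_bytes, channels=3, expand_animations=False)
    img = tf.image.resize(img, [image_size, image_size], method='bicubic')
    return tf.cast(img, tf.float32) / 255.0  # assumed [0, 1] scaling

img_ = preprocess_raw_image(tf.io.read_file(path))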

Gradients do not exist for variables 'maxvit/block_00_00/attention/relative_bias:0'

  • In the maxvit.py class Attention (#L216-L238) implementation, self.relative_bias = self.add_weight(...) is created as a weight first, and then the rearranged self.reindexed_bias = attn_utils.reindex_2d_einsum_lookup(...) is computed in build, not in call. I am not sure about this, but in my understanding the rearranged self.reindexed_bias is a copy of the initially created weights, so the original weights will never be updated during training. I think the self.reindexed_bias = attn_utils.reindex_2d_einsum_lookup(...) operation should be moved into the call block, e.g. into maxvit.py#L258 (see the sketch after this list).
  • This Colab, maxvit_train_test.ipynb, is a basic training test of whether relative_bias changes during training. Then again, maybe I'm just not using it correctly; please help confirm.
  • BTW, here is my Keras implementation of MaxViT: keras_cv_attention_models/maxvit on GitHub. Currently only the MaxViT_T weights are ported, as the ported relative_bias does not seem right...
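Below is a self-contained toy reproducing the pattern in question, with hypothetical names rather than the repo's actual code: any tensor derived from a trainable weight must be recomputed inside call(); caching it in build() freezes a snapshot of the initial values and disconnects the variable from the gradient tape:

import tensorflow as tf

class BiasInCall(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.relative_bias = self.add_weight(
            name='relative_bias', shape=(input_shape[-1],),
            initializer='zeros', trainable=True)
        # Do NOT compute a derived tensor here (the analogue of
        # reindex_2d_einsum_lookup): build() runs once, so the result
        # would be a constant snapshot disconnected from the variable.

    def call(self, inputs):
        # Recompute from the live variable on every forward pass so the
        # op stays on the gradient tape between the loss and relative_bias.
        derived_bias = 2.0 * self.relative_bias  # stand-in for the reindex op
        return inputs + derived_bias

layer = BiasInCall()
x = tf.ones([1, 4])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(layer(x))
print(tape.gradient(loss, layer.trainable_variables))  # non-None gradient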

Maxvit local global

Why, when computing local (block) attention, is the feature map reshaped to (H/P × W/P, P², C), i.e. with P² placed in the second-to-last dimension?

Yet when computing global (grid) attention, the feature map is first reshaped to (G², H/G × W/G, C) and then the second-to-last and third-to-last dimensions are swapped to obtain (H/G × W/G, G², C). Since this final form is the same as the local case, why not apply the same reshape directly instead of performing the extra dimension swap?
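A sketch of the two partitions (assumed standard Multi-Axis reshapes, not copied verbatim from this repo) makes the axis bookkeeping explicit. In the grid path, the grid index is the outer factor of each spatial axis, so the tokens of one attention group are not contiguous in memory; a plain reshape therefore cannot land directly on (H/G × W/G, G², C), and a permutation is unavoidable:

import tensorflow as tf

def block_partition(x, p):
    # Local attention: contiguous P x P windows.
    # (B, H, W, C) -> (B, H/P * W/P, P*P, C); attention runs over axis -2.
    b, h, w, c = x.shape
    x = tf.reshape(x, [b, h // p, p, w // p, p, c])
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])  # gather the two window dims
    return tf.reshape(x, [b, (h // p) * (w // p), p * p, c])

def grid_partition(x, g):
    # Global attention: a G x G grid of tokens strided by (H/G, W/G).
    # The grid index is the OUTER factor of each spatial axis, so the
    # natural flattening is (G*G, H/G * W/G, C); the transpose below is
    # what swaps the roles to reach (H/G * W/G, G*G, C).
    b, h, w, c = x.shape
    x = tf.reshape(x, [b, g, h // g, g, w // g, c])
    x = tf.transpose(x, [0, 2, 4, 1, 3, 5])  # move grid dims to the token axis
    return tf.reshape(x, [b, (h // g) * (w // g), g * g, c])

x = tf.random.normal([2, 56, 56, 64])
print(block_partition(x, 7).shape)  # (2, 64, 49, 64)
print(grid_partition(x, 7).shape)   # (2, 64, 49, 64)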

maxvit-gan?

How did you use MaxViT in GANs, object detection, and segmentation? Is that code available? And is fine-tuning this model on a different dataset documented anywhere?

Is there a PyTorch implementation?

Hi, I am new here; thanks for your great work. I am not very familiar with TensorFlow syntax. Is there a PyTorch implementation of the model?

Bug: get_config() not implemented in MaxViT

I got the following error when trying to train and save a model:

NotImplementedError:
Layer MaxViT has arguments ['config']
in `__init__` and therefore must override `get_config()`.

Example:

class CustomLayer(keras.layers.Layer):
    def __init__(self, arg1, arg2):
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2

    def get_config(self):
        config = super().get_config()
        config.update({
            "arg1": self.arg1,
            "arg2": self.arg2,
        })
        return config
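A hedged workaround sketch, assuming the layer's only extra constructor argument is `config` (the import path, constructor signature, and the config object's JSON-serializability are all assumptions here, not confirmed against the repo):

import tensorflow as tf
from maxvit.models.maxvit import MaxViT  # import path assumed

class SerializableMaxViT(MaxViT):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self._maxvit_config = config  # keep a handle for serialization

    def get_config(self):
        # Once get_config is overridden, the base-class implementation
        # returns the standard layer fields instead of raising.
        base = super().get_config()
        base.update({'config': self._maxvit_config})  # must be serializable
        return base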

Getting "The process cannot access the file because it is being used by another process" error on Windows

I tried to install MaxViT on Windows, but get:

C:\Projects*\venv\lib\site-packages\setuptools\command\install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
C:\Projects*\venv\lib\site-packages\setuptools\command\easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
zip_safe flag not set; analyzing archive contents...
error: [WinError 32] The process cannot access the file because it is being used by another process: 'c:\projects\**\venv\lib\site-packages\maxvit-1.0.0-py3.10.egg'
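A plausible workaround, an assumption rather than a maintainer-confirmed fix: the traceback goes through the deprecated setup.py install / easy_install egg path, so installing with pip directly (which builds a wheel instead of an egg) may avoid the locked-file error:

# from the repository root, inside the activated venv
pip install .
# or, for an editable development install:
pip install -e .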

COCO Training Details

Thanks for the wonderful paper; it was a pleasure to read!

Could you kindly elaborate a bit more on the COCO training details? In particular, I was wondering about the following three points:

  • What augmentations were used? Just large- or small-scale jittering and perhaps horizontal flipping?
  • What learning rate schedule was used?
  • Was the same learning rate used for all layers?
