
xcit's Introduction

Cross-Covariance Image Transformer (XCiT)

PyTorch implementation and pretrained models for XCiT. See the paper: XCiT: Cross-Covariance Image Transformers

[arXiv] [Yannic Kilcher's video]

Linear complexity in time and memory

Our XCiT models have linear complexity w.r.t. the number of patches/tokens:

[Plots: peak memory (inference) and milliseconds/image (inference) as a function of the number of tokens]
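The linear scaling comes from cross-covariance attention (XCA), which attends over feature channels rather than over tokens, so the attention map is d x d instead of N x N. Below is a minimal, self-contained sketch of the idea in PyTorch; it follows the description in the paper but is not a copy of the repository's implementation, so treat the details as illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossCovarianceAttention(nn.Module):
    """Minimal sketch of XCA: attention over channels instead of tokens.

    The (head_dim x head_dim) attention map makes the cost linear in the
    number of tokens N, unlike the N x N map of standard self-attention.
    """
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # learnable temperature, as described in the paper
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

    def forward(self, x):                       # x: (B, N, C)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 4, 1)    # each: (B, heads, head_dim, N)
        # L2-normalize along the token dimension before the channel-wise product
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        # (head_dim x head_dim) attention: cost is O(N * d^2), linear in N.
        # Splitting into heads makes the full d x d map block-diagonal, one block per head.
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)  # back to (B, N, C)
        return self.proj(out)
```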

Scaling to high resolution inputs

XCiT can scale to high resolution inputs both because of its cheaper compute requirements and because it adapts better to higher resolutions at test time (see Figure 3 in the paper).
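A rough way to observe this scaling on your own hardware is to run the same model at several resolutions and compare peak GPU memory. The sketch below assumes the repository's xcit.py is importable (so that model constructors such as xcit_small_12_p16 are available) and that a CUDA device is present:

```python
import torch
from xcit import xcit_small_12_p16  # assumes the repo's xcit.py is on PYTHONPATH

model = xcit_small_12_p16().cuda().eval()

for size in (224, 384, 1024):
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        model(torch.randn(1, 3, size, size, device="cuda"))
    peak_mib = torch.cuda.max_memory_allocated() / 2**20
    print(f"{size}x{size}: peak memory {peak_mib:.0f} MiB")
```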

Detection and Instance Segmentation for Ultra high resolution images (6000x4000)

Detection and instance segmentation results for an ultra high resolution image (6000x4000).

XCiT+DINO: High Res. Self-Attention Visualization 🦖

Our XCiT models with self-supervised training using DINO can obtain high resolution attention maps.

[Video: xcit_dino.mp4]

Self-Attention visualization per head

Below we show the attention maps for each of the 8 heads separately. Every head specializes in a different semantic aspect of the scene, covering the foreground as well as the background.

[Video: Multi_head.mp4]

Getting Started

First, clone the repo

git clone https://github.com/facebookresearch/XCiT.git

Then, install the required packages, including PyTorch 1.7.1, torchvision 0.8.2, and timm 0.4.8:

pip install -r requirements.txt

Download and extract the ImageNet dataset. Afterwards, set the --data-path argument to the corresponding extracted ImageNet path.
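The data loading follows the usual DeiT/torchvision ImageFolder convention, i.e. one sub-folder per class under both train/ and val/. If your extraction produced flat folders of JPEGs, the loader will find 0 files. A quick sanity check (paths are placeholders):

```python
from torchvision.datasets import ImageFolder

# Expected layout:
#   /path/to/imagenet/train/n01440764/*.JPEG
#   /path/to/imagenet/train/n01443537/*.JPEG
#   ...
#   /path/to/imagenet/val/n01440764/*.JPEG
#   ...
for split in ("train", "val"):
    ds = ImageFolder(f"/path/to/imagenet/{split}")
    print(split, len(ds), "images in", len(ds.classes), "classes")
```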

For full details about all the available arguments, you can use

python main.py --help

For detection and segmentation downstream tasks, please check the detection and semantic_segmentation folders in this repository.


Model Zoo

We provide pre-trained weights for XCiT models on ImageNet-1k.

§: distillation

Models with 16x16 patch size

| Arch | Params | 224 top-1 | weights | 224 § top-1 | weights | 384 § top-1 | weights |
|---|---|---|---|---|---|---|---|
| xcit_nano_12_p16 | 3M | 69.9% | download | 72.2% | download | 75.4% | download |
| xcit_tiny_12_p16 | 7M | 77.1% | download | 78.6% | download | 80.9% | download |
| xcit_tiny_24_p16 | 12M | 79.4% | download | 80.4% | download | 82.6% | download |
| xcit_small_12_p16 | 26M | 82.0% | download | 83.3% | download | 84.7% | download |
| xcit_small_24_p16 | 48M | 82.6% | download | 83.9% | download | 85.1% | download |
| xcit_medium_24_p16 | 84M | 82.7% | download | 84.3% | download | 85.4% | download |
| xcit_large_24_p16 | 189M | 82.9% | download | 84.9% | download | 85.8% | download |

Models with 8x8 patch size

| Arch | Params | 224 top-1 | weights | 224 § top-1 | weights | 384 § top-1 | weights |
|---|---|---|---|---|---|---|---|
| xcit_nano_12_p8 | 3M | 73.8% | download | 76.3% | download | 77.8% | download |
| xcit_tiny_12_p8 | 7M | 79.7% | download | 81.2% | download | 82.4% | download |
| xcit_tiny_24_p8 | 12M | 81.9% | download | 82.6% | download | 83.7% | download |
| xcit_small_12_p8 | 26M | 83.4% | download | 84.2% | download | 85.1% | download |
| xcit_small_24_p8 | 48M | 83.9% | download | 84.9% | download | 85.6% | download |
| xcit_medium_24_p8 | 84M | 83.7% | download | 85.1% | download | 85.8% | download |
| xcit_large_24_p8 | 189M | 84.4% | download | 85.4% | download | 86.0% | download |

XCiT + DINO Self-supervised models

| Arch | Params | k-NN | Linear | Download |
|---|---|---|---|---|
| xcit_small_12_p16 | 26M | 76.0% | 77.8% | backbone |
| xcit_small_12_p8 | 26M | 77.1% | 79.2% | backbone |
| xcit_medium_24_p16 | 84M | 76.4% | 78.8% | backbone |
| xcit_medium_24_p8 | 84M | 77.9% | 80.3% | backbone |
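To use a released classification checkpoint outside of main.py, you can build the model from the repository's xcit.py and load the downloaded weights. A minimal sketch, assuming the DeiT-style convention that the weights live under the checkpoint's "model" key (the URL is the S12/16 224 model from the table above):

```python
import torch
from xcit import xcit_small_12_p16  # model definitions from this repo

model = xcit_small_12_p16()
url = "https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth"
checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
# DeiT-style checkpoints store the weights under the "model" key
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
model.load_state_dict(state_dict)
model.eval()
```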

Training

For training using a single node, use the following command

python -m torch.distributed.launch --nproc_per_node=[NUM_GPUS] --use_env main.py --model [MODEL_KEY] --batch-size [BATCH_SIZE] --drop-path [STOCHASTIC_DEPTH_RATIO] --output_dir [OUTPUT_PATH]

For example, the XCiT-S12/16 model can be trained using the following command

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model xcit_small_12_p16 --batch-size 128 --drop-path 0.05 --output_dir /experiments/xcit_small_12_p16/ --epochs [NUM_EPOCHS]

For multi-node training via SLURM, you can alternatively use:

python run_with_submitit.py --partition [PARTITION_NAME] --nodes 2 --ngpus 8 --model xcit_small_12_p16 --batch-size 64 --drop-path 0.05 --job_dir /experiments/xcit_small_12_p16/ --epochs 400

More details for the hyper-parameters used to train the different models can be found in Table B.1 in the paper.

Evaluation

To evaluate an XCiT model using the checkpoints above or models you trained, use the following command:

python main.py --eval --model <MODEL_KEY> --input-size <IMG_SIZE> [--full_crop] --pretrained <PATH/URL>

By default we use the --full_crop flag, which evaluates the model with a crop ratio of 1.0 instead of 0.875, following CaiT.
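For reference, the two settings differ only in the resize that precedes the center crop. A hedged sketch of the evaluation preprocessing (the real pipeline is built with timm and also uses bicubic interpolation; this is only meant to illustrate the crop-ratio arithmetic):

```python
from torchvision import transforms

def eval_transform(input_size=224, crop_ratio=1.0):
    # crop_ratio=0.875 resizes to 256 then center-crops 224 (the usual ImageNet eval);
    # crop_ratio=1.0 (--full_crop) resizes directly to the target size, as in CaiT.
    resize_size = int(input_size / crop_ratio)
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
```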

For example, the command to evaluate the XCiT-S12/16 using 224x224 images:

python main.py --eval --model xcit_small_12_p16 --input-size 224 --full_crop --pretrained https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth

Acknowledgement

This repository is built using the Timm library and the DeiT repository. The self-supervised training is based on the DINO repository.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Contributing

We actively welcome your pull requests! Please see CONTRIBUTING.md and CODE_OF_CONDUCT.md for more info.

Citation

If you find this repository useful, please consider citing our work:

@article{el2021xcit,
  title={XCiT: Cross-Covariance Image Transformers},
  author={El-Nouby, Alaaeldin and Touvron, Hugo and Caron, Mathilde and Bojanowski, Piotr and Douze, Matthijs and Joulin, Armand and Laptev, Ivan and Neverova, Natalia and Synnaeve, Gabriel and Verbeek, Jakob and others},
  journal={arXiv preprint arXiv:2106.09681},
  year={2021}
}

xcit's People

Contributors

aelnouby, tanujdhiman


xcit's Issues

Code parts from Swin-Transformer-Object-Detection

Hello, as you mention, you borrowed some parts from the Swin-Transformer-Object-Detection repository for the detector. It seems that some of them are required but not included in this repo; I am referring to parts relevant to training, for example EpochBasedRunnerAmp and DistOptimizerHook. I already have previous experience with mmcv/mmdetection, so I know how to include them, but are you going to officially include them in the repo?

Can't read files from the datasets folder

Thanks for your work!

I ran into a problem when trying to follow the steps in the readme.md.

I try to set the datapath by:
python main.py --data-path 'F:\Projects\AI\ImageNet'
However, it fails with:
RuntimeError: Found 0 files in subfolders of: F:\Projects\AI\ImageNet\train

I'm sure the path is right, and the train folder contains .JPEG files, but it seems none of them can be read.

Block-diagonal XCA?

Hello,
It seems that XCA does not use separate parameters for each head, and I can't find an implementation of the block-diagonal variant. I'm curious why the implementation doesn't include it. Sorry if my understanding is incorrect.

detectron2 integration?

Why not use detectron2 to replicate some of the detection experiment results, since this is work from FAIR?

ImageNet-22K Models?

Greetings,

I would first like to thank you for sharing the code of this amazing research.

Looking at other competitive backbone models, such as Swin Transformers, I see that their best results come from models pre-trained on ImageNet-22K. Would it be possible for you to also release 22K pre-trained models? It would allow the models to reach even higher accuracy and possibly outperform Swin-L, which currently achieves 87.3 top-1.

Question about training epochs and training logs

Hello, thank you for another simple but effective work!

In your paper, the number of training epochs is set as:

We train our model for 400 epochs with the AdamW optimizer [45] using a cosine learning rate decay.

but the default in your code is 300 epochs, and it is not changed on the command line.

So, I'm confused about it.

By the way, could you publish your training logs?

Extremely unstable training on multiple gpus

Hi, I'm trying to reproduce the classification training results.

I tried on 2 different machines, machine A with one RTX 3090 and machine B with four A100 gpus.

The training on machine A with a single GPU is fine; see green line (with default parameters).
But on machine B with 4 GPUs, training is very erratic and does not progress properly; see the gray, yellow, and teal lines (with default and custom parameters).
Purple line is DeiT training on the same machine B (default parameters).

All experiments were done with --batch-size=128 (128 samples per GPU).

This is the validation loss; other metrics tell the same story, some even worse.
[Screenshot taken 2022-01-05: validation loss curves]

Example of the commands I used:

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
    --model xcit_small_12_p16 --batch-size 128 --drop-path 0.05 --epochs 400

Has anyone seen this or does anyone know how to fix it? Many thanks.

Log Files from Training

Thank you for your awesome code!

I am hoping you might open-source the log files you have from training. Maybe the training and validation loss as a function of epoch
(and/or batch) with an estimate of the runtime?

Warning: Grad strides do not match bucket view strides.

Hi, thanks for your wonderful work. When I use XCiT as the backbone for another task, I get the warning: "Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. It is caused by feeding incontiguous tensors to view-style operators." I can find some of the places that cause this warning, but there seem to be several different code lines that can trigger it, and I failed to find all of them. Have you also encountered this warning, and do you have any advice on how to solve it?
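Not an official fix, but this DDP warning is commonly triggered when the output of a transpose/permute flows into view-style operators, and the mitigation most users report is to insert .contiguous() at those points. A generic illustration with hypothetical shapes and names:

```python
import torch

def merge_heads(x: torch.Tensor) -> torch.Tensor:
    """x: (B, heads, head_dim, N) -> (B, N, heads * head_dim).

    Adding .contiguous() after permute/transpose is the mitigation most users
    report for the DDP "grad strides do not match bucket view strides" warning:
    it keeps downstream ops (and their gradients) in the standard memory layout.
    """
    B, H, D, N = x.shape
    return x.permute(0, 3, 1, 2).contiguous().reshape(B, N, H * D)
```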

Training on Single GPU

Thanks for the exciting work.

I am trying to fine-tune on my ImageNet-like classification dataset on 1 GPU using the following command.

python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --model xcit_nano_12_p16 --batch-size 16 --drop-path 0.05 --output_dir experiments/xcit_nano_12_p16/ --epochs 30 --pretrained /mnt/hdd1/Projects/XCiT/xcit_nano_12_p16_224.pth

But it fails with the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [16, 1, 128]], which is output 0 of SliceBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

What could be done to resolve this? I am new to distributed training.

ERROR in requirements.txt

Getting the following error while installing from requirements.txt:

ERROR: Could not find a version that satisfies the requirement timm==0.4.8 (from -r requirements.txt (line 3)) (from versions: 0.1.1, 0.1.2, 0.1.4, 0.1.6, 0.1.8, 0.1.10, 0.1.12, 0.1.14, 0.1.16, 0.1.18, 0.1.20, 0.1.22, 0.1.24, 0.1.26, 0.1.28, 0.1.30, 0.2.1, 0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.4.5, 0.4.9)
ERROR: No matching distribution found for timm==0.4.8 (from -r requirements.txt (line 3))

I think we have to change the version from 0.4.8 to 0.4.9

Cannot fine-tune from a pretrained checkpoint with a single GPU?

Hello, I am having trouble loading the pretrained checkpoint of "xcit-s-12/8" with a single GPU. It seems to activate torch.distributed. Any way around it?

raise RuntimeError( RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

The suggested solution of using "tokens_norm=True" does not work for me; setting it to True or False gives the same result.

training logs

Hello, this is really amazing work. Could you provide the training logs? Did you ever encounter NaN in your distillation training?

colab

Great work! Please add a Google Colab demo for inference.

Welcome update to OpenMMLab 2.0


I am Vansin, the technical operator of OpenMMLab. In September of last year, we announced the release of OpenMMLab 2.0 at the World Artificial Intelligence Conference in Shanghai. We invite you to upgrade your algorithm library to OpenMMLab 2.0 using MMEngine, which can be used for both research and commercial purposes. If you have any questions, please feel free to join us on the OpenMMLab Discord at https://discord.gg/amFNsyUBvm or add me on WeChat (van-sin) and I will invite you to the OpenMMLab WeChat group.

Here are the OpenMMLab 2.0 repos branches:

| Project | OpenMMLab 1.0 branch | OpenMMLab 2.0 branch |
|---|---|---|
| MMEngine | | 0.x |
| MMCV | 1.x | 2.x |
| MMDetection | 0.x, 1.x, 2.x | 3.x |
| MMAction2 | 0.x | 1.x |
| MMClassification | 0.x | 1.x |
| MMSegmentation | 0.x | 1.x |
| MMDetection3D | 0.x | 1.x |
| MMEditing | 0.x | 1.x |
| MMPose | 0.x | 1.x |
| MMDeploy | 0.x | 1.x |
| MMTracking | 0.x | 1.x |
| MMOCR | 0.x | 1.x |
| MMRazor | 0.x | 1.x |
| MMSelfSup | 0.x | 1.x |
| MMRotate | 1.x | 1.x |
| MMYOLO | | 0.x |

Attention: please create a new virtual environment for OpenMMLab 2.0.

An Issue !

Is this particular repository only for the detection model (i.e., COCO), or can we add more models to it?

Thanks

Loading of checkpoint fails

The semseg training command (from semantic_segmentation/README.md)

tools/dist_train.sh configs/xcit/sem_fpn/sem_fpn_xcit_small_12_p16_80k_ade20k.py 8 --work-dir /path/to/save --seed 0 --deterministic --options model.pretrained=https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth

fails with

TypeError: EncoderDecoder: XCiT: __init__() got an unexpected keyword argument 'pretrained'

Loss is not decreasing

I use 4x1080ti

python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model xcit_nano_12_p8 --batch-size 64 --drop-path 0.05 --output_dir ./experiments/xcit_nano_12_p8/ --epochs 100

I got this

Epoch: [1]  [  0/390]  eta: 0:27:35  lr: 0.000001  loss: 6.8866 (6.8866)  time: 4.2456  data: 2.6132  max mem: 6128
Epoch: [1]  [ 10/390]  eta: 0:07:40  lr: 0.000001  loss: 6.9260 (6.9281)  time: 1.2121  data: 0.2377  max mem: 6128
Epoch: [1]  [ 20/390]  eta: 0:06:37  lr: 0.000001  loss: 6.9244 (6.9252)  time: 0.9165  data: 0.0001  max mem: 6128
Epoch: [1]  [ 30/390]  eta: 0:06:09  lr: 0.000001  loss: 6.9226 (6.9256)  time: 0.9227  data: 0.0001  max mem: 6128
Epoch: [1]  [ 40/390]  eta: 0:05:49  lr: 0.000001  loss: 6.9321 (6.9277)  time: 0.9192  data: 0.0001  max mem: 6128
Epoch: [1]  [ 50/390]  eta: 0:05:35  lr: 0.000001  loss: 6.9378 (6.9297)  time: 0.9253  data: 0.0001  max mem: 6128
Epoch: [1]  [ 60/390]  eta: 0:05:21  lr: 0.000001  loss: 6.9406 (6.9320)  time: 0.9267  data: 0.0001  max mem: 6128
Epoch: [1]  [ 70/390]  eta: 0:05:10  lr: 0.000001  loss: 6.9406 (6.9321)  time: 0.9265  data: 0.0001  max mem: 6128
Epoch: [1]  [ 80/390]  eta: 0:04:58  lr: 0.000001  loss: 6.9337 (6.9325)  time: 0.9315  data: 0.0001  max mem: 6128
Epoch: [1]  [ 90/390]  eta: 0:04:48  lr: 0.000001  loss: 6.9340 (6.9328)  time: 0.9339  data: 0.0001  max mem: 6128
Epoch: [1]  [100/390]  eta: 0:04:38  lr: 0.000001  loss: 6.9246 (6.9323)  time: 0.9438  data: 0.0001  max mem: 6128
Epoch: [1]  [110/390]  eta: 0:04:27  lr: 0.000001  loss: 6.9255 (6.9317)  time: 0.9371  data: 0.0001  max mem: 6128
Epoch: [1]  [120/390]  eta: 0:04:17  lr: 0.000001  loss: 6.9293 (6.9317)  time: 0.9224  data: 0.0001  max mem: 6128
Epoch: [1]  [130/390]  eta: 0:04:07  lr: 0.000001  loss: 6.9322 (6.9319)  time: 0.9176  data: 0.0001  max mem: 6128
Epoch: [1]  [140/390]  eta: 0:03:57  lr: 0.000001  loss: 6.9306 (6.9320)  time: 0.9286  data: 0.0001  max mem: 6128
Epoch: [1]  [150/390]  eta: 0:03:47  lr: 0.000001  loss: 6.9294 (6.9319)  time: 0.9332  data: 0.0001  max mem: 6128
Epoch: [1]  [160/390]  eta: 0:03:38  lr: 0.000001  loss: 6.9265 (6.9313)  time: 0.9317  data: 0.0001  max mem: 6128
Epoch: [1]  [170/390]  eta: 0:03:28  lr: 0.000001  loss: 6.9265 (6.9318)  time: 0.9253  data: 0.0001  max mem: 6128
Epoch: [1]  [180/390]  eta: 0:03:18  lr: 0.000001  loss: 6.9412 (6.9319)  time: 0.9212  data: 0.0001  max mem: 6128
Epoch: [1]  [190/390]  eta: 0:03:08  lr: 0.000001  loss: 6.9273 (6.9319)  time: 0.9292  data: 0.0001  max mem: 6128
Epoch: [1]  [200/390]  eta: 0:02:59  lr: 0.000001  loss: 6.9255 (6.9318)  time: 0.9226  data: 0.0001  max mem: 6128
Epoch: [1]  [210/390]  eta: 0:02:49  lr: 0.000001  loss: 6.9305 (6.9317)  time: 0.9275  data: 0.0001  max mem: 6128
Epoch: [1]  [220/390]  eta: 0:02:40  lr: 0.000001  loss: 6.9295 (6.9314)  time: 0.9360  data: 0.0001  max mem: 6128
Epoch: [1]  [230/390]  eta: 0:02:30  lr: 0.000001  loss: 6.9290 (6.9312)  time: 0.9446  data: 0.0001  max mem: 6128
Epoch: [1]  [240/390]  eta: 0:02:21  lr: 0.000001  loss: 6.9229 (6.9305)  time: 0.9421  data: 0.0001  max mem: 6128
Epoch: [1]  [250/390]  eta: 0:02:11  lr: 0.000001  loss: 6.9263 (6.9310)  time: 0.9283  data: 0.0001  max mem: 6128
Epoch: [1]  [260/390]  eta: 0:02:02  lr: 0.000001  loss: 6.9225 (6.9305)  time: 0.9232  data: 0.0001  max mem: 6128
Epoch: [1]  [270/390]  eta: 0:01:52  lr: 0.000001  loss: 6.9225 (6.9307)  time: 0.9220  data: 0.0001  max mem: 6128
Epoch: [1]  [280/390]  eta: 0:01:43  lr: 0.000001  loss: 6.9359 (6.9309)  time: 0.9205  data: 0.0001  max mem: 6128
Epoch: [1]  [290/390]  eta: 0:01:33  lr: 0.000001  loss: 6.9323 (6.9307)  time: 0.9232  data: 0.0001  max mem: 6128
Epoch: [1]  [300/390]  eta: 0:01:24  lr: 0.000001  loss: 6.9245 (6.9304)  time: 0.9327  data: 0.0001  max mem: 6128
Epoch: [1]  [310/390]  eta: 0:01:15  lr: 0.000001  loss: 6.9237 (6.9304)  time: 0.9280  data: 0.0001  max mem: 6128
Epoch: [1]  [320/390]  eta: 0:01:05  lr: 0.000001  loss: 6.9333 (6.9307)  time: 0.9234  data: 0.0001  max mem: 6128
Epoch: [1]  [330/390]  eta: 0:00:56  lr: 0.000001  loss: 6.9372 (6.9308)  time: 0.9362  data: 0.0001  max mem: 6128
Epoch: [1]  [340/390]  eta: 0:00:46  lr: 0.000001  loss: 6.9314 (6.9306)  time: 0.9338  data: 0.0001  max mem: 6128
Epoch: [1]  [350/390]  eta: 0:00:37  lr: 0.000001  loss: 6.9309 (6.9308)  time: 0.9319  data: 0.0001  max mem: 6128
Epoch: [1]  [360/390]  eta: 0:00:28  lr: 0.000001  loss: 6.9258 (6.9307)  time: 0.9357  data: 0.0001  max mem: 6128
Epoch: [1]  [370/390]  eta: 0:00:18  lr: 0.000001  loss: 6.9239 (6.9303)  time: 0.9330  data: 0.0002  max mem: 6128
Epoch: [1]  [380/390]  eta: 0:00:09  lr: 0.000001  loss: 6.9205 (6.9301)  time: 0.9424  data: 0.0001  max mem: 6128
Epoch: [1]  [389/390]  eta: 0:00:00  lr: 0.000001  loss: 6.9206 (6.9300)  time: 0.9389  data: 0.0001  max mem: 6128

Finetuning details

Hello, this is great work. I would like to take advantage of the models' lower VRAM requirements to deploy them on edge devices.

However, I would like to ask for resources on how to fine-tune the models on our own data. Does fine-tuning follow the standard approach of replacing the classification head (the final fully connected layer, perhaps?) and then training with a lower learning rate (what would you advise as a general baseline)?

Thanks in advance for any pointers; again, great work!

Fine-tuning configurations

Thank you for the great work.
I'm trying to reproduce the transfer learning results, but I'm not sure about the fine-tuning configuration.
I read the issue below, but is that all (just a smaller lr=5e-5 for classification)?
#9 (comment)

Could you let me know if there should be additional changes?
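For readers with the same question, here is a minimal, unofficial sketch of the standard fine-tuning recipe discussed in the two issues above: load a released checkpoint, replace the classification head, and train with a small learning rate (5e-5 is the value mentioned above). All names and hyper-parameters are illustrative, not the authors' configuration.

```python
import torch
import torch.nn as nn
from xcit import xcit_small_12_p16  # model definitions from this repo (assumed importable)

num_classes = 10  # your dataset

model = xcit_small_12_p16()
ckpt = torch.hub.load_state_dict_from_url(
    "https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth", map_location="cpu")
state_dict = ckpt.get("model", ckpt)
# drop the ImageNet-1k classifier weights, then load the rest
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("head.")}
model.load_state_dict(state_dict, strict=False)

# replace the classification head (named "head" in DeiT-style models) for the new task
model.head = nn.Linear(model.head.in_features, num_classes)

# small fine-tuning learning rate; 5e-5 is the value quoted in the issue above
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.05)
```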
