
dino's Introduction

🆕 Please check out our more recent DINOv2 effort in the same line of work.

Self-Supervised Vision Transformers with DINO

PyTorch implementation and pretrained models for DINO. For details, see Emerging Properties in Self-Supervised Vision Transformers.
[blogpost] [arXiv] [Yannic Kilcher's video]

DINO illustration

Pretrained models

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the backbone in onnx format, as well as detailed arguments and training/evaluation logs. Note that DeiT-S and ViT-S names refer exactly to the same architecture.

| arch | params | k-NN | linear | download |
| --- | --- | --- | --- | --- |
| ViT-S/16 | 21M | 74.5% | 77.0% | backbone only, full ckpt, onnx, args, logs, eval logs |
| ViT-S/8 | 21M | 78.3% | 79.7% | backbone only, full ckpt, onnx, args, logs, eval logs |
| ViT-B/16 | 85M | 76.1% | 78.2% | backbone only, full ckpt, onnx, args, logs, eval logs |
| ViT-B/8 | 85M | 77.4% | 80.1% | backbone only, full ckpt, onnx, args, logs, eval logs |
| ResNet-50 | 23M | 67.5% | 75.3% | backbone only, full ckpt, onnx, args, logs, eval logs |

We also release XCiT models ([arXiv] [code]) trained with DINO:

| arch | params | k-NN | linear | download |
| --- | --- | --- | --- | --- |
| xcit_small_12_p16 | 26M | 76.0% | 77.8% | backbone only, full ckpt, args, logs, eval |
| xcit_small_12_p8 | 26M | 77.1% | 79.2% | backbone only, full ckpt, args, logs, eval |
| xcit_medium_24_p16 | 84M | 76.4% | 78.8% | backbone only, full ckpt, args, logs, eval |
| xcit_medium_24_p8 | 84M | 77.9% | 80.3% | backbone only, full ckpt, args, logs, eval |

Pretrained models on PyTorch Hub

import torch
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
xcit_small_12_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p16')
xcit_small_12_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p8')
xcit_medium_24_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p16')
xcit_medium_24_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')
resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
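
For example, here is a minimal sketch of extracting image features with one of these backbones. The image path is hypothetical, and the 224x224 resize with ImageNet normalization is just a common choice, not a requirement:

import torch
from PIL import Image
from torchvision import transforms

model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
model.eval()

# Standard ImageNet-style preprocessing (illustrative choice).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

img = preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0)  # hypothetical image path
with torch.no_grad():
    features = model(img)  # [CLS] embedding of the backbone, shape (1, 384) for ViT-S
print(features.shape)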

Training

Documentation

Please install PyTorch and download the ImageNet dataset. This codebase has been developed with Python version 3.6, PyTorch version 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the args column of the pretrained models section. For a glimpse at the full documentation of DINO training, please run:

python main_dino.py --help

Vanilla DINO training 🦕

Run DINO with the ViT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 days and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide training and linear evaluation logs (with batch size 256 at evaluation time) for this run to help reproducibility.

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
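
For intuition, below is a simplified sketch of the DINO objective that main_dino.py optimizes. The function and variable names are illustrative; the real implementation additionally averages the loss over all pairs of global/local crops and updates the teacher as an EMA of the student.

import torch
import torch.nn.functional as F

def dino_loss(student_out, teacher_out, center, tau_s=0.1, tau_t=0.04, m=0.9):
    # student_out, teacher_out: raw projection-head outputs, shape (batch, out_dim).
    # center: running center of the teacher outputs, shape (1, out_dim).
    # Teacher targets: center, then sharpen with a low temperature, no gradients.
    t = F.softmax((teacher_out - center) / tau_t, dim=-1).detach()
    # Student predictions: softmax with a higher temperature.
    log_s = F.log_softmax(student_out / tau_s, dim=-1)
    loss = -(t * log_s).sum(dim=-1).mean()
    # Update the center with an exponential moving average of the teacher outputs.
    new_center = m * center + (1 - m) * teacher_out.mean(dim=0, keepdim=True)
    return loss, new_center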

Multi-node training

We use Slurm and submitit (pip install submitit). To train on 2 nodes with 8 GPUs each (total 16 GPUs):

python run_with_submitit.py --nodes 2 --ngpus 8 --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
DINO with ViT-base network:
python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base  --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Boosting DINO performance 🦖

You can improve the performance of the vanilla run by:

  • training for more epochs: --epochs 300,
  • increasing the teacher temperature: --teacher_temp 0.07 --warmup_teacher_temp_epochs 30,
  • removing the last layer normalization (only safe with --arch vit_small): --norm_last_layer false.

Full command:
python run_with_submitit.py --arch vit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

The resulting pretrained model should reach 73.3% on k-NN eval and 76.0% on linear eval. Training time is 2.6 days with 16 GPUs. We provide training and linear evaluation logs (with batch size 256 at evaluation time) for this run to help reproducibility.

ResNet-50 and other convnets trainings

This code also works for training DINO on convolutional networks, like ResNet-50 for example. We highly recommend adapting some optimization arguments in this case. For example, the following is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide training logs and the final checkpoint for this run.

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Self-attention visualization

You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:

python visualize_attention.py
Self-attention from a Vision Transformer with 8x8 patches trained with DINO
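
If you want to access the attention maps programmatically, here is a minimal sketch along the lines of what visualize_attention.py does, assuming a backbone loaded from this repo's hub entry points (which expose get_last_selfattention). The image path and the 480x480 resize are illustrative:

import torch
from PIL import Image
from torchvision import transforms

patch_size = 8
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
model.eval()

transform = transforms.Compose([
    transforms.Resize((480, 480)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = transform(Image.open('img.png').convert('RGB')).unsqueeze(0)  # hypothetical image path

with torch.no_grad():
    attn = model.get_last_selfattention(img)  # (1, num_heads, num_tokens, num_tokens)

num_heads = attn.shape[1]
w_feat = h_feat = 480 // patch_size
# Attention of the [CLS] token (query index 0) over the patch tokens, one map per head.
cls_attn = attn[0, :, 0, 1:].reshape(num_heads, h_feat, w_feat)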

Self-attention video generation

You can generate videos like the one on the blog post with video_generation.py.

example.mp4

Extract frames from input video and generate attention video:

python video_generation.py  --pretrained_weights dino_deitsmall8_pretrain.pth \
    --input_path input/video.mp4 \
    --output_path output/ \
    --fps 25

Use folder of frames already extracted and generate attention video:

python video_generation.py  --pretrained_weights dino_deitsmall8_pretrain.pth \
    --input_path output/frames/ \
    --output_path output/ \
    --resize 256

Only generate video from folder of attention maps images:

python video_generation.py --input_path output/attention \
    --output_path output/ \
    --video_only \
    --video_format avi

Evaluation: k-NN classification on ImageNet

To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet

If you choose not to specify --pretrained_weights, then DINO reference weights are used by default. If you want instead to evaluate checkpoints from a run of your own, you can run for example:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet 
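
Conceptually, eval_knn.py extracts frozen features for the train and validation splits and then classifies each validation image with a similarity-weighted vote among its nearest training features. Below is a minimal sketch of that voting step; k=20 and T=0.07 mirror common defaults here, and other details of the script are omitted.

import torch
import torch.nn.functional as F

def knn_classify(train_feats, train_labels, test_feats, num_classes, k=20, T=0.07):
    # Soft k-NN vote over L2-normalized features (cosine similarity).
    train_feats = F.normalize(train_feats, dim=1)
    test_feats = F.normalize(test_feats, dim=1)
    sim = test_feats @ train_feats.T                 # (num_test, num_train)
    topk_sim, topk_idx = sim.topk(k, dim=1)          # k nearest training features
    topk_labels = train_labels[topk_idx]             # (num_test, k)
    weights = (topk_sim / T).exp()                   # similarity-weighted votes
    votes = torch.zeros(test_feats.size(0), num_classes)
    votes.scatter_add_(1, topk_labels, weights)      # accumulate votes per class
    return votes.argmax(dim=1)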

Evaluation: Linear classification on ImageNet

To train a supervised linear classifier on frozen weights on a single node with 8 gpus, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet
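
In essence, the script keeps the DINO backbone frozen and trains only a linear classifier on top of its features. Below is a minimal sketch of that idea; the hub model, learning rate, and single-[CLS]-feature setup are illustrative simplifications of what eval_linear.py actually does.

import torch
import torch.nn as nn

backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
backbone.eval()
for p in backbone.parameters():
    p.requires_grad = False  # the backbone stays frozen

linear = nn.Linear(384, 1000)  # 384 = ViT-S embedding dim, 1000 ImageNet classes
optimizer = torch.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def train_step(images, labels):
    with torch.no_grad():
        feats = backbone(images)  # frozen DINO features
    loss = criterion(linear(feats), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()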

We release the logs and weights from evaluating the different models:

| arch | top-1 ImageNet | linear evaluation |
| --- | --- | --- |
| ViT-S/16 | 77.0% | linear weights, logs |
| ViT-S/8 | 79.7% | linear weights, logs |
| ViT-B/16 | 78.2% | linear weights, logs |
| ViT-B/8 | 80.1% | linear weights, logs |
| xcit_small_12_p16 | 77.8% | linear weights, logs |
| xcit_small_12_p8 | 79.2% | linear weights, logs |
| xcit_medium_24_p16 | 78.8% | linear weights, logs |
| xcit_medium_24_p8 | 80.3% | linear weights, logs |
| ResNet-50 | 75.3% | linear weights, logs |

You can check the performance of the pretrained weights on the ImageNet validation set by running the following command lines:

python eval_linear.py --evaluate --arch vit_small --patch_size 16 --data_path /path/to/imagenet/train
python eval_linear.py --evaluate --arch vit_small --patch_size 8 --data_path /path/to/imagenet/train
python eval_linear.py --evaluate --arch vit_base --patch_size 16 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train
python eval_linear.py --evaluate --arch vit_base --patch_size 8 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train
python eval_linear.py --evaluate --arch resnet50 --data_path /path/to/imagenet/train

Evaluation: DAVIS 2017 Video object segmentation

Please verify that you're using PyTorch version 1.7.1, since we are not able to reproduce the results with the more recent PyTorch 1.8.1 at the moment.

Step 1: Prepare DAVIS 2017 data

cd $HOME
git clone https://github.com/davisvideochallenge/davis-2017 && cd davis-2017
./data/get_davis.sh

Step 2: Video object segmentation

python eval_video_segmentation.py --data_path $HOME/davis-2017/DAVIS/ --output_dir /path/to/saving_dir
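
At a high level, this script propagates the first-frame mask by matching dense patch features between frames. Below is a minimal single-source-frame sketch of that nearest-neighbor propagation; the actual script uses several past frames, a restricted spatial neighborhood, and other refinements.

import torch
import torch.nn.functional as F

def propagate_labels(src_feats, src_labels, tgt_feats, k=5):
    # src_feats, tgt_feats: (num_patches, dim) dense patch features from the frozen ViT.
    # src_labels: (num_patches, num_classes) one-hot (or soft) mask labels of the source frame.
    src = F.normalize(src_feats, dim=1)
    tgt = F.normalize(tgt_feats, dim=1)
    sim = tgt @ src.T                        # cosine similarity, (num_tgt, num_src)
    topk_sim, topk_idx = sim.topk(k, dim=1)  # k most similar source patches per target patch
    weights = F.softmax(topk_sim, dim=1)     # soft weights over the neighbors
    # Weighted average of the neighbors' labels gives the target frame's soft mask.
    return (weights.unsqueeze(-1) * src_labels[topk_idx]).sum(dim=1)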

Step 3: Evaluate the obtained segmentation

git clone https://github.com/davisvideochallenge/davis2017-evaluation $HOME/davis2017-evaluation
python $HOME/davis2017-evaluation/evaluation_method.py --task semi-supervised --results_path /path/to/saving_dir --davis_path $HOME/davis-2017/DAVIS/

Evaluation: Image Retrieval on revisited Oxford and Paris

Step 1: Prepare revisited Oxford and Paris by following this repo.

Step 2: Image retrieval (if you do not specify weights with --pretrained_weights then by default DINO weights pretrained on Google Landmark v2 dataset will be used).

Paris:

python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 512 --multiscale 1 --data_path /path/to/revisited_paris_oxford/ --dataset rparis6k

Oxford:

python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 224 --multiscale 0 --data_path /path/to/revisited_paris_oxford/ --dataset roxford5k

Evaluation: Copy detection on Copydays

Step 1: Prepare Copydays dataset.

Step 2 (optional): Prepare a set of image distractors and a set of images on which to learn the whitening operator. In our paper, we use 10k random images from YFCC100M as distractors and 20k random images from YFCC100M (different from the distractors) for computing the whitening operation.

Step 3: Run copy detection:

python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_copy_detection.py --data_path /path/to/copydays/ --whitening_path /path/to/whitening_data/ --distractors_path /path/to/distractors/

We report results on the "strong" subset. For example, in the stdout from the command above we get: eval on strong mAP=0.858.
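
For reference, the whitening mentioned in Step 2 can be thought of as a PCA-whitening transform learned on the held-out whitening images and then applied to all descriptors before matching. Below is a minimal sketch of such a transform (a simplification, not the exact procedure of eval_copy_detection.py):

import torch
import torch.nn.functional as F

def fit_whitening(feats, eps=1e-6):
    # feats: (n, d) descriptors from the set used to learn the whitening.
    mean = feats.mean(dim=0, keepdim=True)
    centered = feats - mean
    # Right singular vectors of the centered data are the covariance eigenvectors.
    _, s, v = torch.svd(centered)
    scale = ((feats.shape[0] - 1) ** 0.5) / s.clamp(min=eps)  # 1 / sqrt(eigenvalue)
    projection = v * scale  # (d, d) whitening projection
    return mean, projection

def apply_whitening(feats, mean, projection):
    whitened = (feats - mean) @ projection
    return F.normalize(whitened, dim=1)  # re-L2-normalize before similarity search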

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Citation

If you find this repository useful, please consider giving a star ⭐ and citation 🦖:

@inproceedings{caron2021emerging,
  title={Emerging Properties in Self-Supervised Vision Transformers},
  author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e  and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  booktitle={Proceedings of the International Conference on Computer Vision (ICCV)},
  year={2021}
}

dino's People

Contributors

aquadzn, mathildecaron31, mgpadalkar, piotr-bojanowski, timdarcet, user1234554321, wuyongfa-genius


dino's Issues

Question: Supervised Attention Visualization

I was wondering about the attention maps used for visualizing the supervised training model. As far as I can understand, in the source code the last-layer attention weights are used to visualize the saliency masks. Is this the same approach used for visualizing the supervised model in the visual comparison we have in Figure 4 of the paper?

If so, in papers such as Quantifying Attention Flow in Transformers it is argued that final attention maps can't be directly mapped back to input tokens in a meaningful way since information is mixed during forward propagation of multiple self-attention blocks. What are your views on this?

Definitely, having saliency maps as a byproduct of self-supervised training is way more valuable than supervised training in the sense of zero-shot learning. I was curious: if last-layer attention maps are used for the supervised visualizations, wouldn't it be fairer to use an approach like attention flow instead? Also, would using this approach give different results for ViTs trained with DINO? I am also not sure whether we can make sense of different heads with the attention-flow approach, and it's pretty cool to see that different heads are able to localize different objects in the case of DINO.

Thank you! :)

Error finetuning from pretrained checkpoint

Hi all, I'm running into an error when trying to fine-tune from one of the pretrained checkpoints.

Code

!mkdir "$output"
!wget -q -O "$output/checkpoint.pth" https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth

!python -m torch.distributed.launch \
  --nproc_per_node=1 ./dino/main_dino.py \
  --arch deit_small \
  --data_path "$input" \
  --output_dir "$output"

Error

| distributed init (rank 0): env://
git:
  sha: 8aa93fdc90eae4b183c4e3c005174a9f634ecfbf, status: clean, branch: main

arch: deit_small
batch_size_per_gpu: 64
...
...
Student and Teacher are built: they are both deit_small network.
Loss, optimizer and schedulers ready.
Found checkpoint at ./drive/MyDrive/DINO/checkpoint.pth
=> failed to load student from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load teacher from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load optimizer from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load fp16_scaler from checkpoint './drive/MyDrive/DINO/checkpoint.pth'
=> failed to load dino_loss from checkpoint './drive/MyDrive/DINO/checkpoint.pth'

Any suggestions would be very much appreciated.

Error when loading pretrained checkpoint for finetunning

Dear authors,

thank you very much for this repo. I would like to use the pre-trained weights and finetune the network for a different dataset using self-supervised learning with DINO.
However, when I try to use the code, I cannot load the optimizer's state and I get the following output:

Found checkpoint at ./dino_ft_workdir/checkpoint.pth
=> loaded student from checkpoint './dino_ft_workdir/checkpoint.pth' with msg <All keys matched successfully>
=> loaded teacher from checkpoint './dino_ft_workdir/checkpoint.pth' with msg <All keys matched successfully>
=> loaded optimizer from checkpoint './dino_ft_workdir/checkpoint.pth'
=> failed to load fp16_scaler from checkpoint './dino_ft_workdir/checkpoint.pth'
=> failed to load dino_loss from checkpoint './dino_ft_workdir/checkpoint.pth'

fp16_scaler and dino_loss are not in the checkpoint so it is clear why they are not loaded.
I found out that the problem with the optimizer is caused by this:

ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group.

Is there anyone who would be able to help me?

Thank you very much in advance!

NameError: name 'max_accuracy' is not defined

I have tried to run eval_linear.py after training DINO on a custom dataset. I get the following error:
Traceback (most recent call last):
File "/home/ubuntu/dino/eval_linear.py", line 250, in
eval_linear(args)
File "/home/ubuntu/dino/eval_linear.py", line 142, in eval_linear
"Top-1 test accuracy: {acc:.1f}".format(acc=max_accuracy))
NameError: name 'max_accuracy' is not defined
Killing subprocess 44659
Traceback (most recent call last):
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/distributed/launch.py", line 340, in
main()
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/distributed/launch.py", line 326, in main
sigkill_handler(signal.SIGTERM, None) # not coming back
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/ubuntu/anaconda3/envs/pytorch/bin/python', '-u', 'eval_linear.py', '--local_rank=0', '--data_path', 'images', '--pretrained_weights', 'runs/checkpoint.pth']' returned non-zero exit status 1.


After skimming through the lines above it, I believe it should be best_acc rather than max_accuracy ?
Note: I have changed the dataset class (pytorch dataset) so as to work on my custom dataset but I believe the error still stands as I tried finding max_accuracy in the entire file and found only one occurrence (i.e., not defined earlier).

ImageNet dataset

Hello,
Very interesting work! I was just wondering about the dataset used. Is it the full ImageNet or just the ILSVRC subset?
Thanks!
Tim

About the intermediate layer features

How can I extract intermediate layer features from the ViT-B/8 model?
I tried:
model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
modules = list(model.children())[:-2]
model = nn.Sequential(*modules)
but model(x.to(device)) raises an error: forward() takes 1 argument but 2 were given

Error using visualize_attention.py. The size of tensor a (3234) must match the size of tensor b (3181) at non-singleton dimension 1

Hi all, I am trying to execute visualize_attention.py with default pretrained weights on my own image as below

!python visualize_attention.py --image_path 'test/finalImg_249.png'

I get a size mismatch error. Could you please let me know what changes need to be made here?

Error stack trace:

Please use the --pretrained_weights argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3458: UserWarning: Default upsampling behavior when mode=bicubic is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.

"See the documentation of nn.Upsample for details.".format(mode)
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3503: UserWarning: The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details.
"The default behavior for interpolate/upsample with float scale_factor changed "

Traceback (most recent call last):
File "visualize_attention.py", line 162, in
attentions = model.forward_selfattention(img.to(device))
File "~/dino/vision_transformer.py", line 246, in forward_selfattention
x = x + pos_embed

RuntimeError: The size of tensor a (3234) must match the size of tensor b (3181) at non-singleton dimension 1

Image details:
import cv2
img = cv2.imread('finalImg_249.png')
print (img.shape) #output: (427, 488, 3)

Download error ConnectionResetError: [WinError 10054] An existing connection was forcibly closed by the remote host

Hi, I am downloading the small model as in the README

import torch
deits8 = torch.hub.load('facebookresearch/dino:main', 'dino_deits8')

and getting this error

Downloading: "https://github.com/facebookresearch/dino/archive/main.zip" to C:\Users\Igor/.cache\torch\hub\main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" to C:\Users\Igor/.cache\torch\hub\checkpoints\dino_deitsmall8_pretrain.pth
11%
8.96M/82.7M [00:14<01:56, 663kB/s]

---------------------------------------------------------------------------
ConnectionResetError                      Traceback (most recent call last)
<ipython-input-22-fae2c58f62a6> in <module>
      1 import torch
----> 2 deits8 = torch.hub.load('facebookresearch/dino:main', 'dino_deits8')

~\anaconda3\lib\site-packages\torch\hub.py in load(repo_or_dir, model, *args, **kwargs)
    368         repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose)
    369 
--> 370     model = _load_local(repo_or_dir, model, *args, **kwargs)
    371     return model
    372 

~\anaconda3\lib\site-packages\torch\hub.py in _load_local(hubconf_dir, model, *args, **kwargs)
    397 
    398     entry = _load_entry_from_hubconf(hub_module, model)
--> 399     model = entry(*args, **kwargs)
    400 
    401     sys.path.remove(hubconf_dir)

~/.cache\torch\hub\facebookresearch_dino_main\hubconf.py in dino_deits8(pretrained, **kwargs)
     30     model = vits.__dict__["deit_small"](patch_size=8, num_classes=0, **kwargs)
     31     if pretrained:
---> 32         state_dict = torch.hub.load_state_dict_from_url(
     33             url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
     34             map_location="cpu",

~\anaconda3\lib\site-packages\torch\hub.py in load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name)
    553             r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
    554             hash_prefix = r.group(1) if r else None
--> 555         download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    556 
    557     if _is_legacy_zip_format(cached_file):

~\anaconda3\lib\site-packages\torch\hub.py in download_url_to_file(url, dst, hash_prefix, progress)
    445                   unit='B', unit_scale=True, unit_divisor=1024) as pbar:
    446             while True:
--> 447                 buffer = u.read(8192)
    448                 if len(buffer) == 0:
    449                     break

~\anaconda3\lib\http\client.py in read(self, amt)
    456             # Amount is given, implement using readinto
    457             b = bytearray(amt)
--> 458             n = self.readinto(b)
    459             return memoryview(b)[:n].tobytes()
    460         else:

~\anaconda3\lib\http\client.py in readinto(self, b)
    500         # connection, and the user is reading more bytes than will be provided
    501         # (for example, reading in 1k chunks)
--> 502         n = self.fp.readinto(b)
    503         if not n and b:
    504             # Ideally, we would raise IncompleteRead if the content-length

~\anaconda3\lib\socket.py in readinto(self, b)
    667         while True:
    668             try:
--> 669                 return self._sock.recv_into(b)
    670             except timeout:
    671                 self._timeout_occurred = True

~\anaconda3\lib\ssl.py in recv_into(self, buffer, nbytes, flags)
   1239                   "non-zero flags not allowed in calls to recv_into() on %s" %
   1240                   self.__class__)
-> 1241             return self.read(nbytes, buffer)
   1242         else:
   1243             return super().recv_into(buffer, nbytes, flags)

~\anaconda3\lib\ssl.py in read(self, len, buffer)
   1097         try:
   1098             if buffer is not None:
-> 1099                 return self._sslobj.read(len, buffer)
   1100             else:
   1101                 return self._sslobj.read(len)

ConnectionResetError: [WinError 10054] An existing connection was forcibly closed by the remote host

Any possible solutions? What if I download this model via browser and put it into torch cache folder, will it work?

Questions about the Relation to SwAV

Hi, @mathildecaron31. I have a question about a detail in the paper. In Appendix B, Relation to SwAV (Table 14), the paper does an ablative study over the composing parts, namely the momentum encoder and the extra operation. I'm wondering whether the softmax over the channel dimension is appended in the experiments other than Centering (i.e., experiments 2, 3, 5, 6), since DINO uses an extra channel-wise softmax after centering to compute the loss, while in SwAV the outputs after the batch-level softmax/SK algorithm are used to compute the cross-entropy.

Issue with `dino_resnet50`

torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
Using cache found in /Users/thomas/.cache/torch/hub/facebookresearch_dino_main
Traceback (most recent call last):
File "", line 1, in
File "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.7/site-packages/torch/hub.py", line 339, in load
model = _load_local(repo_or_dir, model, *args, **kwargs)
File "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.7/site-packages/torch/hub.py", line 368, in _load_local
model = entry(*args, **kwargs)
File "/Users/thomas/.cache/torch/hub/facebookresearch_dino_main/hubconf.py", line 82, in dino_resnet50
model.load_state_dict(state_dict, strict=True)
File "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
self.class.name, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "fc.weight", "fc.bias".

The patch size is hard coded in hubconf.py

For example:

dino/hubconf.py

Lines 55 to 61 in 8aa93fd

def dino_vitb8(pretrained=True, **kwargs):
"""
ViT-Base/8x8 pre-trained with DINO.
Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification.
"""
model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs)
if pretrained:

Is that intended (i.e. the model will not work if patch_size=16)? If not I'd be happy to contribute a PR to fix that.

Advice about the evaluation script on Video Object Segmentation

First, thanks for your excellent work on self-supervised ViTs. It is very impressive to see the visualizations of the attentions. Just a little advice about the evaluation script for video object segmentation: I noticed that the script computes (1+n_last_frames) source frame features every time there is a new target frame, which is unnecessary. In fact, we can just put the (1+n_last_frames) source frame features into the queue and instead only compute the target frame's (the current frame to propagate) feature, which will save much computation and thus make the evaluation much faster.
Of course this is not the focus of this study. Just a little advice.

Nan loss

I am using the following dataset. It is a subset of ImageNet.
https://github.com/rmccorm4/Tiny-Imagenet-200

After training one step, the loss just becomes nan and stops training. Have you experienced this problem? And how do you solve it?

Best regards!

What is outputted by the pretrained model?

I ran

deits16 = torch.hub.load('facebookresearch/dino:main', 'dino_deits16')

to retrieve the pretrained model. This takes an image and outputs a vector of length 384. What is this vector? Is this a representation of the image? And if it is, can I use this pretrained network to create representations of images/patches that I can use for clustering?

Secondly,

in the visualize_attention.py file from line 199 - 206 we save images for all attention heads in a for loop. How are the images in the paper generated then as those are single images? Are they a combination of the heatmaps outputted by the model? And if they are, how are they combined? By average or by summation?

Rationale behind DeiT-S/8 being better than ViT-B/8 at k-NN?

Here is a quote from this comment #8 (comment):

As a matter of fact, on copy detection datasets, I've found the base models to perform clearly better than the small ones: I get better performance with Base16x16 than with Small8x8 though Small8x8 is better at k-NN ImNet.

I assume this is about Table 2 from the article.

Table 2

We see that for both of the tasks (Linear and k-NN):

  • DeiT-S/16 is worse than ViT-B/16,
  • the /8 architectures are better than the /16 architectures.

However, when it comes to comparing the /8 architectures:

  • DeiT-S/8 is worse than ViT-B/8 at the Linear task, as expected,
  • DeiT-S/8 is better than ViT-B/8 at the k-NN task, which is intriguing.

What is the rationale behind DeiT-S/8 being better than ViT-B/8 at the k-NN ImNet task?

Onnx pretrained

Your work looks very interesting.
I'm not familiar with Pytorch / Python and it would be great if the pre-trained nets could be provided in ONNX format.

Regards Armin

copy detection

@mathildecaron31 I have a question about copy detection. I am trying to evaluate the pretrained DINO models on a dataset for copy detection task and I am trying to follow the steps from the paper. Even with different image input sizes in Table 4 we see that final embedding dimension is 1536. I am not able to understand how we can get same embedding dimension after concatenating CLS embedding and GeM pooled output patch tokens for different input image sizes. Maybe I am missing a point here. Here is what I did:

Added the following method to VisionTransformer to return output patch tokens and cls output.

def forward_output_patch_tokens_cls(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        if self.norm is not None:
            x = self.norm(x)

        return x

Using GeM module from here

def gem(x, p=3, eps=1e-6):
    "x: BS x num tokens x embed_dim"
    return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1./p)
    
class GeM(nn.Module):

    def __init__(self, p=3, eps=1e-6):
        super(GeM,self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)
        
    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'

Collect embeddings (CLS + GeM Pooled Output Patch Tokens)

all_image_features = []
with torch.no_grad():
    for imgb in progress_bar(image_dl):
        outputs = model.forward_output_patch_tokens_cls(imgb.cuda())
        cls_token, output_patch_tokens = outputs[:,0],outputs[:,1:]
        
        cls_features   = cls_token   
        patch_features = gem_pooling(output_patch_tokens.permute(0,2,1)).squeeze(-1)
        concat_features = torch.cat([cls_features,patch_features],dim=-1)
        all_image_features.append(concat_features.cpu())

Following this and using an image size of 224 for dino_vitb8, my final embedding dimension is 1536 (corrected from my initial 1568), which can also be calculated as:

cls_feature_dim*2 = 768*2

Question
Also, during the copy detection task, do you learn the pooling parameter p or is it picked based on a validation set? I didn't quite understand the whitening part; is it the same as regular unsupervised PCA?

Found this paper: https://hal.inria.fr/hal-00722622v2/document. I believe idea is coming from here.

Edit:

Figured out the 1536 dimension size. We need to pool across token positions, so this gives a pooled embedding with the same dimension as the CLS token embedding.

Originally posted by @KeremTurgutlu in #8 (comment)

Regarding the segmentation results

This is very impressive work, and thanks for your code. Regarding the results in the table of Figure 4, I want to ask how to generate the multi-label segmentation maps, i.e., how to associate the self-attention maps with different classes.

Different resolutions for global and local crops

Hi @mathildecaron31! Thanks for the great package and concise codebase :)

The global crops have resolution 224 x 224 and the local ones are 96 x 96:

transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),

Was it important that the local ones are kept 96 x 96?

I'm building an extension, and the code of course becomes simpler if all views are 224 x 224 (since you can collate everything together and don't need separate forward passes per resolution), but I wonder if you think it would impact the accuracy (besides the slower computation time).

Thanks!

How to use "visualize_attention.py" to process all images in folder?

Hi, I wonder how to use "python visualize_attention.py" to process all the images in my folder.

I've tried to use "python visualize_attention.py --output_dir /dino/out --image_path ... " but it can only process 1 image.

I want to use the processed images to multiply with the original images and then get the background-removed data.

If there is any way, please tell me. Thanks!

Smaller model implementation

Hello,

Do you have plans to check the performance of smaller models like mobilenet_v2 or v3 (they are available in torchvision_models)?
If not, I may look into this task. Do you think that a small CNN like MobileNet (with depth 1 or even 0.35) is capable of learning these representations?

result on coco datasets

This is excellent work.
About the downstream COCO detection task:
I want to know if there are results (mAP) on the COCO dataset.
Thanks.

license

would it be ok to make a web demo for the attention map visualization part on https://gradio.app/ under the current license?

colab

Can a Colab notebook for inference be added, please?

`interpolate_pos_encoding(x, pos_embed)` doesn't return the correct dimensions for images that are not square (w != h)

I noticed that the generation of the positional embedding in the interpolate_pos_encoding method is slightly different from the one in the forward_selfattention method. The following simple modification brings the two in line, in case it is of interest.

    def interpolate_pos_encoding(self, x, pos_embed, w, h):  # passing w and h as arguments
        npatch = x.shape[1] - 1
        N = pos_embed.shape[1] - 1
        if npatch == N:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size  # just copy paste from forward_selfattention
        h0 = h // self.patch_embed.patch_size
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),  # replace math.sqrt(npatch / N) with one from forward_selfattention
            mode='bicubic',
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)

RuntimeError: Given groups=1, weight of size [384, 3, 8, 8], expected input[1, 0, 512, 585] to have 3 channels, but got 0 channels instead

Hi all, I have problem running the inference script, https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422

The images I use are originally grayscale and have been converted to 3 channels. Now the shape of my images is (427, 488, 3).

When I run predict_video(args) I get the following error. Could you please help me here? thanks

I converted the images to 3 channels with the following code.

import cv2
import numpy as np
def saveAsRGB(image_path):
  for filepath in sorted(glob.iglob(image_path + "/*")):
    img = cv2.imread(filepath)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img2 = np.zeros_like(img)
    img2[:,:,0] = gray
    img2[:,:,1] = gray
    img2[:,:,2] = gray
    cv2.imwrite(filepath, img2)

 img = cv2.imread ("~input/finalImg_242.png")
 print(img.shape) #output : (427, 488, 3)

Also as you could see from the below code, inside predict_video(args) function, I tried to print the shape of the image and confirmed that it has 3 channels ([3, 512, 585]).

import numpy as np
import cv2
import glob
def predict_video(args):

    for filepath in sorted(glob.iglob(args.image_path + "/*")):
        print("filepath",filepath)
        img = Image.open(filepath)
        img = img.convert('RGB')
        transform = pth_transforms.Compose([
            pth_transforms.ToTensor(),
            pth_transforms.Resize(512),
            pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        img = transform(img)

        print(".img.shape",img.shape)  # output is torch.Size([3, 512, 585])

Error:

RuntimeError Traceback (most recent call last)
in ()
----> 1 predict_video(args)

6 frames
in predict_video(args)
41
42
---> 43 attentions = model.forward_selfattention(img.cuda())
44
45 nh = attentions.shape[1] # number of head

~/vision_transformer.py in forward_selfattention(self, x)
220 B, nc, w, h = x.shape
221 N = self.pos_embed.shape[1] - 1
--> 222 x = self.patch_embed(x)
223
224 # interpolate patch embeddings

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),

~/vision_transformer.py in forward(self, x)
116 def forward(self, x):
117 B, C, H, W = x.shape
--> 118 x = self.proj(x).flatten(2).transpose(1, 2)
119 return x
120

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in forward(self, input)
397
398 def forward(self, input: Tensor) -> Tensor:
--> 399 return self._conv_forward(input, self.weight, self.bias)
400
401 class Conv3d(_ConvNd):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
394 _pair(0), self.dilation, self.groups)
395 return F.conv2d(input, weight, bias, self.stride,
--> 396 self.padding, self.dilation, self.groups)
397
398 def forward(self, input: Tensor) -> Tensor:

RuntimeError: Given groups=1, weight of size [384, 3, 8, 8], expected input[1, 0, 512, 585] to have 3 channels, but got 0 channels instead

Evaluation time

Wondering how long it takes to run the linear eval with the default settings?

model collapse after a few steps

I use custom data to train DINO; the model seems to collapse after a few steps and the features become uniform. I used a larger teacher temperature to enhance "sharpening", but the model collapsed nonetheless. I wonder if DINO is sensitive to the data; in other words, does DINO tend to collapse when training on different data?

How are the patch size and input size set for the DAVIS mask propagation experiment?

Hi,
Thanks for sharing your implementation of dino.

We are trying to figure out how you set the size of the patches when evaluating on DAVIS.
You said in the paper that the input resolution is 480 for DAVIS and the images are not square.
However, the ViT model seems to be pre-trained with an input resolution of 224 and a patch size of 16.
So there is a kind of sequence-length mismatch between pre-training and testing.

Can you give more detail on this?

Centering and Sharpening balance

Hello, thanks for sharing your great work with codes!

I was wondering if you experimented with no sharpening and no centering at all? I was thinking that if using either one causes some collapse, why not use neither of them during training?

Thanks,
Jaejin Cho

Cannot reproduce the transfer learning result of 82.8% on ImageNet. Asking for the fine-tuning code and hyper-parameters for transfer learning on ImageNet.

Thanks for your wonderful work!

I noticed the amazing transfer learning result in Appendix A.
Can I directly use the code and parameter configuration from DeiT to do the fine-tuning and then get the 82.8% top-1 performance on ImageNet?
Is it necessary to change the fine-tuning hyper-parameters or the data augmentation? I noticed that DINO and DeiT use different data augmentation methods.

I used the DeiT code and hyper-parameters to fine-tune DINO-base/16, but only got 81.34% top-1 accuracy on ImageNet, which is even worse than the from-scratch setting of 81.8% top-1.

Head Layers, Attention Layer

Hi,

I looked at the exported ONNX models with https://netron.app/ but I didn't understand where I can find
the "heads from the last layer of a DeiT-S/8 trained with DINO and display the self-attention for [CLS] token query"
mentioned in the caption of Figure 3, so that I can do the same cool segmentation as shown
in the paper.

Regards Armin

KNN evaluation failing for CIFAR10 on ImageNet pre-trained model.

Dear Authors,

Thank you for this amazing work and repository. I wanted to see the clustering capability of the ImageNet-pretrained network on the CIFAR-10 dataset. To do so, I wanted to use the features you extract in the eval_knn.py script. Below are the parameters I used to run the code:

arch= deit_small (i.e vit_small after the new commit)
patch_size=16
batch_size_per_gpu = 32
data_path = 'path_to_cifar10_dataset'
and all other params are set to their default values.

I left pretrained_weights blank so that the model loads the ImageNet weights from the URL mentioned in utils.py.

The script executed successfully with the message:

Data loaded with 50000 train and 10000 val imgs.
Model deit_small 16x16 built.
Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
Pretrained weights found at dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth and loaded with msg: <All keys matched successfully>
Extracting features for train set...
Storing features into tensor of shape torch.Size([50000, 384])

However, I'm facing a strange issue: the extracted features come out as zero vectors for all of the input images.

torch.Size([50000, 384])
torch.Size([50000])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

Can you please specify where I might be going wrong?
Thanks in advance!

Video Instance Segmentation task

In this task, how do you propagate the label from previous frames?

Specifically,

  • What kind of distance do you use in the kNN (cosine or L2)?
  • How many blocks did you use as the feature of a token?

Error in interpolate_pos_encoding method

I found an error that happens in the interpolation method.

It looks like this method only works with square images.

To fix this, we need a different scale factor for each axis, as in the forward_selfattention method.

Single gpu training

Hi,

I'm just wondering, is there a way to train this on a single GPU without distributed launch?

Best,
Jason
