
probing-vits's Introduction

Probing ViTs

Badges: TensorFlow 2.8 | Hugging Face

By Aritra Roy Gosthipaty and Sayak Paul (equal contribution)

In this repository, we provide tools to probe into the representations learned by different families of Vision Transformers (supervised pre-training with ImageNet-21k, ImageNet-1k, distillation, self-supervised pre-training):

  • Original ViT [1]
  • DeiT [2]
  • DINO [3]

We hope these tools prove useful to the community. Please follow along with this post on keras.io for easier navigation through the repository.

Updates

Self-attention visualization

Figure columns: Original Image | Attention Maps | Attention Maps Overlaid
output-dino.mp4

Original Video Source

output-dog.mp4

Original Video Source

Supervised salient representations

In the DINO blog post, the authors show a video with the following caption:

The original video is shown on the left. In the middle is a segmentation example generated by a supervised model, and on the right is one generated by DINO.

A screenshot of the video is as follows:

image

The attention maps generated with the supervised pre-trained model are visibly less salient than those generated with the DINO model. We observe similar behaviour in our own experiments. The figure below shows the attention heatmaps extracted with a ViT-B16 model pre-trained (supervised) on ImageNet-1k:

Dinosaur Dog

We used this Colab Notebook to conduct this experiment.
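
For readers who want a concrete picture of how such per-head heatmaps can be derived, here is a minimal NumPy sketch. It assumes an attention array of shape (num_heads, 1 + num_patches, 1 + num_patches) taken from the last Transformer block, with the CLS token at index 0. The function name `cls_attention_heatmaps` is hypothetical and not taken from this repository, and nearest-neighbour upsampling stands in for whatever interpolation the notebook actually uses.

```python
import numpy as np


def cls_attention_heatmaps(attention, patch_size):
    """Turn CLS-to-patch attention into one image-sized heatmap per head.

    attention: array of shape (num_heads, 1 + num_patches, 1 + num_patches),
        assumed to come from the last block with the CLS token at index 0.
    patch_size: side length of a patch in pixels (16 for ViT-B16).
    """
    num_heads = attention.shape[0]
    # Row 0 is the CLS query; drop column 0 so that only patch keys remain.
    cls_to_patches = attention[:, 0, 1:]
    num_patches = cls_to_patches.shape[-1]
    grid_size = int(round(np.sqrt(num_patches)))
    if grid_size * grid_size != num_patches:
        raise ValueError("expected a square grid of patches")
    maps = cls_to_patches.reshape(num_heads, grid_size, grid_size)
    # Nearest-neighbour upsampling: each patch score fills a patch-sized block.
    return maps.repeat(patch_size, axis=1).repeat(patch_size, axis=2)
```

For a 224x224 input and 16x16 patches, the result has shape (num_heads, 224, 224) and can be overlaid on the input image with any plotting library.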

Hugging Face Spaces

You can now probe into the ViTs with your own input images.

  • Attention Heat Maps
  • Attention Rollout

Visualizing mean attention distances
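
Mean attention distance [1, 4] summarizes, for each attention head, how far away (in pixels) the patches that a query patch attends to are on average. The following is a minimal NumPy sketch of that idea, not the repository's implementation. The function names are hypothetical, and it assumes a square patch grid and that the CLS token is simply dropped without renormalizing the remaining weights.

```python
import numpy as np


def patch_distance_matrix(grid_size, patch_size):
    """Pairwise Euclidean distances, in pixels, between patch positions."""
    rows, cols = np.divmod(np.arange(grid_size * grid_size), grid_size)
    coords = np.stack([rows, cols], axis=-1) * patch_size  # (N, 2)
    diff = coords[:, None, :] - coords[None, :, :]         # (N, N, 2)
    return np.sqrt((diff ** 2).sum(axis=-1))               # (N, N)


def mean_attention_distance(attention, patch_size, num_cls_tokens=1):
    """Mean attention distance per head for one image and one layer.

    attention: array of shape (num_heads, tokens, tokens), rows are queries.
    Returns an array of shape (num_heads,).
    """
    # Assumption: CLS rows and columns are dropped, weights are not renormalized.
    patch_attention = attention[:, num_cls_tokens:, num_cls_tokens:]
    num_patches = patch_attention.shape[-1]
    grid_size = int(round(np.sqrt(num_patches)))
    if grid_size * grid_size != num_patches:
        raise ValueError("expected a square grid of patches")
    distances = patch_distance_matrix(grid_size, patch_size)
    # Attention-weighted distance for every query patch, then mean over queries.
    per_query = (patch_attention * distances[None]).sum(axis=-1)
    return per_query.mean(axis=-1)
```

A head that attends only to the query patch itself gets a distance of 0, while a head that spreads attention uniformly over the image gets a large one, which is what makes the statistic useful for comparing local and global heads across layers.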

Methods

We don't propose any novel methods for probing the representations of neural networks. Instead, we take existing methods and implement them in TensorFlow.

  • Mean attention distance [1, 4]
  • Attention Rollout [5]
  • Visualization of the learned projection filters [1]
  • Visualization of the learned positional embeddings
  • Attention maps from individual attention heads [3]
  • Generation of attention heatmaps from videos [3]
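
As an illustration of one of the methods listed above, here is a minimal NumPy sketch of Attention Rollout [5]. It is a sketch under stated assumptions rather than the repository's TensorFlow code: the function names are hypothetical, heads are fused by averaging, and each layer's attention is assumed to have shape (num_heads, tokens, tokens) with the CLS token at index 0.

```python
import numpy as np


def attention_rollout(attentions):
    """Roll attention out across layers.

    attentions: list of arrays, one per layer (first layer first), each of
        shape (num_heads, tokens, tokens) with rows summing to 1.
    Returns a (tokens, tokens) array of input-to-output attention.
    """
    num_tokens = attentions[0].shape[-1]
    identity = np.eye(num_tokens)
    rollout = identity
    for layer_attention in attentions:
        fused = layer_attention.mean(axis=0)               # average the heads
        fused = fused + identity                           # model the residual path
        fused = fused / fused.sum(axis=-1, keepdims=True)  # re-normalize rows
        rollout = fused @ rollout                          # compose with earlier layers
    return rollout


def cls_rollout_map(attentions, grid_size):
    """Rolled-out attention from the CLS token to each patch, as a 2D grid."""
    rollout = attention_rollout(attentions)
    return rollout[0, 1:].reshape(grid_size, grid_size)
```

Adding the identity before normalizing accounts for the residual connections, so information that skips the attention layer still contributes to the rolled-out map.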

Another interesting repository that also visualizes ViTs in PyTorch: https://github.com/jacobgil/vit-explain.

Notes

We first implemented the above-mentioned architectures in TensorFlow and then populated them with the pre-trained parameters from the official codebases. To validate the port, we evaluated the implementations on the ImageNet-1k validation set and confirmed that the top-1 accuracies matched the reported ones.
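
The top-1 check described above amounts to comparing the arg-max prediction with the ground-truth label for every validation image. A minimal NumPy sketch follows. The function name `top1_accuracy` is hypothetical, and the logits are assumed to have been collected from the ported model beforehand.

```python
import numpy as np


def top1_accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class equals the label.

    logits: array of shape (num_examples, num_classes).
    labels: integer class ids of shape (num_examples,).
    """
    predictions = np.argmax(logits, axis=-1)
    return float(np.mean(predictions == np.asarray(labels)))
```

If the ported weights are correct, this number should agree with the accuracy reported by the official codebase, up to small differences caused by preprocessing.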

We value the spirit of open source. If you spot any bugs in the code or see scope for improvement, don't hesitate to open an issue or contribute a PR. We'd very much appreciate it.

Navigating through the codebase

Our ViT implementations are in vit. We provide utility notebooks in the notebooks directory.

DeiT-related code lives in a separate repository: https://github.com/sayakpaul/deit-tf.

Models

Here are the links to the models where the pre-trained parameters were populated:

Training and visualizing with small datasets

Coming soon!

References

[1] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale: https://arxiv.org/abs/2010.11929

[2] DeiT: https://arxiv.org/abs/2012.12877

[3] DINO: https://arxiv.org/abs/2104.14294

[4] Do Vision Transformers See Like Convolutional Neural Networks?: https://arxiv.org/abs/2108.08810

[5] Quantifying Attention Flow in Transformers: https://arxiv.org/abs/2005.00928

Acknowledgements



probing-vits's Issues

Dino Base models in saved model format

How can the DINO base model checkpoints (feature extractor) in PyTorch or ONNX be converted into the SavedModel format (TensorFlow)? This would help in fine-tuning the DINO model.

Supervised training doesn't help that much for extracting salient representations it seems?

In the DINO blog post, the authors show the following:

image

This is what they say in the video caption:

The original video is shown on the left. In the middle is a segmentation example generated by a supervised model, and on the right is one generated by DINO. (All examples are licensed from Stock.)

We see that the attention maps generated with the supervised pre-trained model are less salient than those generated with the DINO model.

Seems to be verified:

Here's the Colab Notebook that verified it. The notebook is not formatted (be aware).

Add `training=False` when converting DINO weights

Hey @sayakpaul ,
Thanks for the notebook and the awesome work here, very helpful :)

I think there is a small typo in the "load-dino-weights" notebook:

In the "Validating the initial architecture" cell, when calling vit_dino_base, you should pass the training=False argument:

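import tensorflow as tf

# ViTDINOBase and config are defined earlier in the "load-dino-weights" notebook.
vit_dino_base = ViTDINOBase(config)

dummy_inputs = tf.random.normal((2, 224, 224, 3))
# training=False added here so the model runs in inference mode.
outputs, attn_scores = vit_dino_base(dummy_inputs, training=False)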

Otherwise everything works fine :)
Let me know if you want me to open a PR to fix this minor issue.
Cheers,
Antoine
