adaptivetokensampling / ats
Adaptive Token Sampling for Efficient Vision Transformers (ECCV 2022 Oral Presentation)
Home Page: https://adaptivetokensampling.github.io/
License: Apache License 2.0
Hi, I am interested in this project. When I run the code, I get the following error:
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:111: operator(): block: [310285,0,0], thread: [127,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
terminate called after throwing an instance of 'c10::CUDAError'
what(): CUDA error: device-side assert triggered
Exception raised from createEvent at ../aten/src/ATen/cuda/CUDAEvent.h:174 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f0f18be27d2 in
../miniconda3/envs/ATS/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: + 0x267df7a (0x7f0f6bc01f7a in ../miniconda3/envs/ATS/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so)
frame #2: + 0x300568 (0x7f0fcdffa568 in ../miniconda3/envs/ATS/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #3: c10::TensorImpl::release_resources() + 0x175 (0x7f0f18bcb005 in ../miniconda3/envs/ATS/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #4: + 0x1ee569 (0x7f0fcdee8569 in ../miniconda3/envs/ATS/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #5: + 0x4d9c78 (0x7f0fce1d3c78 in ../miniconda3/envs/ATS/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #6: THPVariable_subclass_dealloc(_object*) + 0x292 (0x7f0fce1d3f72 in ../miniconda3/envs/ATS/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #34: __libc_start_main + 0xe7 (0x7f0fd0197c87 in /lib/x86_64-linux-gnu/libc.so.6)
Is there anything wrong? I ran the code as-is and only changed the relative file paths, without modifying anything else.
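For what it's worth, this assert usually means a gather/scatter index exceeds the tensor size along the indexed dimension; setting `CUDA_LAUNCH_BLOCKING=1` makes the failing op report synchronously. Below is a small hypothetical helper (not part of the ATS codebase) that checks indices before `torch.gather`, so the problem surfaces as a readable Python exception instead of a device-side assert:

```python
import torch

def safe_gather(x, dim, index):
    # Hypothetical debugging wrapper (not from the ATS repo): validate
    # indices before torch.gather so an out-of-range index raises a clear
    # Python exception instead of a CUDA device-side assert.
    if index.min() < 0 or index.max() >= x.size(dim):
        raise IndexError(
            f"gather index out of bounds: max {int(index.max())} "
            f"vs size {x.size(dim)} along dim {dim}"
        )
    return torch.gather(x, dim, index)

x = torch.arange(12.0).reshape(3, 4)
idx = torch.tensor([[0, 3], [1, 2], [2, 0]])
out = safe_gather(x, 1, idx)  # gathers columns 0,3 / 1,2 / 2,0 per row
print(out)
```

Running the model on CPU with such a check in place typically pinpoints which index tensor is malformed.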
Hi, may I ask why ATS is differentiable?
In my understanding, because the CDF function (equation (4) in the paper) is piecewise constant, the inverse of CDF (equation (5) in the paper) is also piecewise constant and thus is not differentiable. Did I miss something?
Thank you in advance!
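To make the question concrete, here is my reading of the inverse-transform sampling in Eqs. (4)-(5) as a tiny sketch (an illustration of the concern, not the repository's actual implementation):

```python
import torch

torch.manual_seed(0)
# Illustration of my reading of Eqs. (4)-(5), not the repository's code.
scores = torch.tensor([0.1, 0.4, 0.3, 0.2])  # normalized significance scores
cdf = torch.cumsum(scores, dim=0)            # Eq. (4): piecewise-constant CDF
u = torch.rand(8)                            # uniform samples in [0, 1)
idx = torch.searchsorted(cdf, u)             # Eq. (5): inverse CDF -> indices
idx = idx.clamp(max=scores.numel() - 1)      # guard against float round-off
# idx is integer-valued, so its gradient w.r.t. `scores` is zero almost
# everywhere -- which is exactly the non-differentiability being asked about.
print(idx)
```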
Why is the line `# from .cuda import gather_tokens, scatter_tokens` commented out? This causes an error.
In the file libs/config/defaults.py, ImageNet is indicated as the training dataset and Kinetics as the test dataset. Why is that?
If we want to train on a different dataset, say CIFAR-100, would it suffice to just change the data options in defaults.py?
Also, there seem to be some PCA eigenvalues in defaults.py (computed from the dataset, I suppose?). How would those affect the model when training on a dataset other than ImageNet?
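For reference, here is the kind of override I had in mind, using the key names that appear in the configs below. This is hypothetical: it assumes the codebase has (or is given) a CIFAR-100 loader, and the channel statistics are the commonly used CIFAR-100 values, not ones from this repo.

```yaml
# Hypothetical CIFAR-100 overrides (assumes a CIFAR100 dataset loader
# exists; key names taken from the repo's own configs).
TEST:
  DATASET: CIFAR100
DATA:
  PATH_TO_DATA_DIR: "/path/to/cifar100/"
  TRAIN_CROP_SIZE: 32
  TEST_CROP_SIZE: 32
  MEAN: [0.5071, 0.4865, 0.4409]   # commonly used CIFAR-100 statistics
  STD: [0.2673, 0.2564, 0.2762]
VIT:
  NUM_CLASSES: 100
```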
Your adaptive token sampling interests me a lot.
I am curious how you calculate the FLOPs of DeiT-ATS in Table 1, since the number of selected tokens varies per image, as shown in Figure 6 of the paper. Do you take the mean of the number of selected tokens over all images?
Thanks~
Hello,
I am currently trying to reproduce the results shown in Fig. 5 (b) and (c) in the "not fine-tuned" setting.
Here is my config for the 3-GFLOPs level, Stage 3, not fine-tuned:
TRAIN:
  ENABLE: False
TEST:
  ENABLE: True
  DATASET: ImageNet
  BATCH_SIZE: 1024
  CHECKPOINT_FILE_PATH: "/root/workspace/projects/ATS/models/deit_small_patch16_224-cd65a155.pth"
  NUM_ENSEMBLE_VIEWS: 1
  NUM_SPATIAL_CROPS: 1
  SAVE_RESULTS_PATH: "/root/no_backup/preds_ats.pkl"
DATA:
  PATH_TO_DATA_DIR: "/datasets_local/ImageNet/"
  TEST_CROP_SIZE: 224
  TRAIN_CROP_SIZE: 224
  MEAN: [0.485, 0.456, 0.406]
  STD: [0.229, 0.224, 0.225]
DATA_LOADER:
  NUM_WORKERS: 2
VIT:
  IMG_SIZE: 224
  PATCH_SIZE: 16
  IN_CHANNELS: 3
  NUM_CLASSES: 1000
  EMBED_DIM: 384
  DEPTH: 12
  NUM_HEADS: 6
  MLP_RATIO: 4.0
  QKV_BIAS: True
  QK_SCALE: None
  REPRESENTATION_SIZE: None
  DROP_RATE: 0.0
  ATTN_DROP_RATE: 0.0
  DROP_PATH_RATE: 0.0
  HYBRID_BACKBONE: None
  NORM_LAYER: None
  ATS_BLOCKS: [3]
  NUM_TOKENS: [108, 108, 108, 108, 108, 108, 108, 108, 108, 108, 108, 108]
  DROP_TOKENS: True
NUM_GPUS: 1
And here is my config for the 3-GFLOPs level, multi-stage, not fine-tuned:
TRAIN:
  ENABLE: False
TEST:
  ENABLE: True
  DATASET: ImageNet
  BATCH_SIZE: 1024
  CHECKPOINT_FILE_PATH: "/root/workspace/projects/ATS/models/deit_small_patch16_224-cd65a155.pth"
  NUM_ENSEMBLE_VIEWS: 1
  NUM_SPATIAL_CROPS: 1
  SAVE_RESULTS_PATH: "/root/no_backup/preds_ats.pkl"
DATA:
  PATH_TO_DATA_DIR: "/datasets_local/ImageNet/"
  TEST_CROP_SIZE: 224
  TRAIN_CROP_SIZE: 224
  MEAN: [0.485, 0.456, 0.406]
  STD: [0.229, 0.224, 0.225]
DATA_LOADER:
  NUM_WORKERS: 2
VIT:
  IMG_SIZE: 224
  PATCH_SIZE: 16
  IN_CHANNELS: 3
  NUM_CLASSES: 1000
  EMBED_DIM: 384
  DEPTH: 12
  NUM_HEADS: 6
  MLP_RATIO: 4.0
  QKV_BIAS: True
  QK_SCALE: None
  REPRESENTATION_SIZE: None
  DROP_RATE: 0.0
  ATTN_DROP_RATE: 0.0
  DROP_PATH_RATE: 0.0
  HYBRID_BACKBONE: None
  NORM_LAYER: None
  ATS_BLOCKS: [3, 4, 5, 6, 7, 8, 9, 10, 11]
  NUM_TOKENS: [108, 108, 108, 108, 108, 108, 108, 108, 108, 108, 108, 108]
  DROP_TOKENS: True
NUM_GPUS: 1
However, I am not able to reach the Top-1 accuracy you report in these figures. Could you please provide the config files used to produce Fig. 5 (b) and (c)?
Thank you in advance!
If there is a guide on this topic, please point me to it (the paper discusses it, but the repo has few hints). Starting from scratch by reading the full source code is painful and time-consuming. Thanks a lot. :)