
Vision Mamba

Efficient Visual Representation Learning with Bidirectional State Space Model

Lianghui Zhu1 *,Bencheng Liao1 *,Qian Zhang2, Xinlong Wang3, Wenyu Liu1, Xinggang Wang1 📧

1 Huazhong University of Science and Technology, 2 Horizon Robotics, 3 Beijing Academy of Artificial Intelligence

(*) equal contribution, (📧) corresponding author.

arXiv Preprint (arXiv 2401.09417), Hugging Face Page (🤗 2401.09417)

News

  • Feb. 10th, 2024: We have updated the Vim-tiny/small weights and training scripts. By placing the class token in the middle of the token sequence, Vim achieves improved results. Further details can be found in the code and our updated arXiv paper.

  • Jan. 18th, 2024: We released our paper on arXiv. Code and models are coming soon. Please stay tuned! ☕️

Abstract

Recently, state space models (SSMs) with efficient hardware-aware designs, i.e., the Mamba deep learning model, have shown great potential for long-sequence modeling. Meanwhile, building efficient and generic vision backbones purely upon SSMs is an appealing direction. However, representing visual data is challenging for SSMs due to the position sensitivity of visual data and the requirement of global context for visual understanding. In this paper, we show that the reliance on self-attention for visual representation learning is not necessary, and we propose a new generic vision backbone with bidirectional Mamba blocks (Vim), which marks the image sequences with position embeddings and compresses the visual representation with bidirectional state space models. On ImageNet classification, COCO object detection, and ADE20K semantic segmentation tasks, Vim achieves higher performance than well-established vision transformers like DeiT, while also demonstrating significantly improved computation and memory efficiency. For example, Vim is 2.8x faster than DeiT and saves 86.8% of GPU memory when performing batch inference to extract features on images with a resolution of 1248x1248. The results demonstrate that Vim is capable of overcoming the computation and memory constraints on performing Transformer-style understanding for high-resolution images, and it has great potential to be the next-generation backbone for vision foundation models.
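To make the mechanism concrete, here is a minimal sketch of the bidirectional idea: patch tokens (with position embeddings already added) are scanned once left-to-right and once right-to-left, and the two results are merged through a residual connection. The make_ssm factory is a hypothetical stand-in for a Mamba-style selective-scan module, not the repository's actual block.

```python
import torch
import torch.nn as nn

class ToyBidirectionalBlock(nn.Module):
    """Toy sketch: scan the token sequence in both directions and merge."""
    def __init__(self, dim, make_ssm):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ssm_f = make_ssm(dim)  # forward-direction sequence model (hypothetical)
        self.ssm_b = make_ssm(dim)  # backward-direction sequence model (hypothetical)

    def forward(self, x):
        # x: (batch, seq_len, dim) patch tokens, position embeddings already added
        h = self.norm(x)
        out_f = self.ssm_f(h)                  # left-to-right scan
        out_b = self.ssm_b(h.flip(1)).flip(1)  # right-to-left scan
        return x + (out_f + out_b) / 2         # residual merge of both directions
```

Any sequence module with a (batch, seq_len, dim) -> (batch, seq_len, dim) signature can be plugged in for make_ssm to experiment with the pattern.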

Overview

Envs. for Pretraining

  • Python 3.10.13

    • conda create -n your_env_name python=3.10.13
  • torch 2.1.1 + cu118

    • pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
  • Requirements: vim_requirements.txt

    • pip install -r vim/vim_requirements.txt
  • Install causal_conv1d and mamba

    • pip install -e causal-conv1d (pointing pip at the bundled source directory; writing causal_conv1d>=1.1.0 makes the shell treat >= as a redirect, see the install issue below)
    • pip install -e mamba-1p1p1
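After installing, a quick sanity check that the CUDA kernels import and run (a minimal sketch; essentially the same test as in the environment-sharing issue further below):

```python
import torch
from mamba_ssm import Mamba

x = torch.randn(2, 64, 16, device="cuda")  # (batch, length, dim)
y = Mamba(d_model=16).to("cuda")(x)
assert y.shape == x.shape
print("causal_conv1d / mamba_ssm kernels OK")
```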

Train Your Vim

bash vim/scripts/pt-vim-t.sh

Train Your Vim at Finer Granularity

bash vim/scripts/ft-vim-t.sh

Model Weights

Model      | #param. | Top-1 Acc. | Top-5 Acc. | Huggingface Repo
-----------|---------|------------|------------|-----------------
Vim-tiny   | 7M      | 76.1       | 93.0       | https://huggingface.co/hustvl/Vim-tiny-midclstok
Vim-tiny+  | 7M      | 78.3       | 94.2       | https://huggingface.co/hustvl/Vim-tiny-midclstok
Vim-small  | 26M     | 80.5       | 95.1       | https://huggingface.co/hustvl/Vim-small-midclstok
Vim-small+ | 26M     | 81.6       | 95.4       | https://huggingface.co/hustvl/Vim-small-midclstok

Notes:

  • + means that we fine-tune at finer granularity with a short schedule.

Evaluation on Provided Weights

To evaluate Vim-Ti on ImageNet-1K, run:

python main.py --eval --resume /path/to/ckpt --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --data-path /path/to/imagenet
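If loading fails (see the issues further below), it can help to inspect the checkpoint before evaluation; a minimal sketch, with the path as a placeholder:

```python
import torch

ckpt = torch.load("/path/to/ckpt", map_location="cpu")
# training checkpoints typically wrap the weights under a "model" key
state = ckpt["model"] if "model" in ckpt else ckpt
print(f"{len(state)} tensors")
for name in list(state)[:5]:
    print(name, tuple(state[name].shape))
```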

Acknowledgement ❤️

This project is based on Mamba (paper, code), Causal-Conv1d (code), and DeiT (paper, code). Thanks for their wonderful work.

Citation

If you find Vim useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.

 @article{vim,
  title={Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model},
  author={Lianghui Zhu and Bencheng Liao and Qian Zhang and Xinlong Wang and Wenyu Liu and Xinggang Wang},
  journal={arXiv preprint arXiv:2401.09417},
  year={2024}
}

vim's People

Contributors

jingfengyao, legendbc, unrealluver, xinggangw


vim's Issues

How to test pretrained model on CIFAR dataset and coco2017

hi,

Since the pretrained model's head has 1000 classes, how can I test it on the CIFAR dataset, which only has 100 classes?
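One common approach (a sketch, not an official recommendation) is to replace the 1000-way ImageNet classifier with a new 100-way head and fine-tune on CIFAR-100. This assumes the model exposes its classifier as a `head` linear layer, as timm-style models do; a dummy backbone stands in for the real model here:

```python
import torch.nn as nn

class DummyBackbone(nn.Module):
    """Stand-in for the pretrained model; only the head-swap pattern matters."""
    def __init__(self, dim=192, num_classes=1000):
        super().__init__()
        self.head = nn.Linear(dim, num_classes)

model = DummyBackbone()
# replace the 1000-way ImageNet classifier with a 100-way CIFAR-100 head
model.head = nn.Linear(model.head.in_features, 100)
# fine-tune (or at least retrain the head) before evaluating on CIFAR
```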

For the coco2017 dataset, could you provide the data preprocessing code about how to construct it properly for evaluation?
here is my data folder structure for the coco dataset:

dataset/
  coco/
    annotations/
    images/
      test/
      train/

Issue with Loading Models from Huggingface Repo

Hi,
Thank you for your fantastic project.

I've encountered a problem when trying to evaluate the models downloaded from the Huggingface repo given in your project, with both the small and tiny Vim models. When I load the models, I receive the following error message:
"size mismatch for pos_embed: copying a param with shape torch.Size([1, 730, 384]) from checkpoint, the shape in the current model is torch.Size([1, 197, 384])."

Could you please offer some guidance on how to address this problem? Your assistance would be greatly appreciated.

Thank you so much for looking into this matter!
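For context, a 730-token pos_embed corresponds to a checkpoint fine-tuned at a higher resolution (27x27 patches plus one class token), while the instantiated 224-input model expects 14x14 plus one (197). Either build the model at the matching input size, or interpolate the position embeddings; below is a minimal interpolation sketch (a hypothetical helper, not part of this repo; note that the midclstok models place the class token mid-sequence, so the split would need adjusting for them):

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, new_grid):
    # pos_embed: (1, 1 + g*g, dim); assumes one extra (class) token at index 0
    cls_tok, grid_tok = pos_embed[:, :1], pos_embed[:, 1:]
    g = int(grid_tok.shape[1] ** 0.5)
    grid_tok = grid_tok.reshape(1, g, g, -1).permute(0, 3, 1, 2)
    grid_tok = F.interpolate(grid_tok, size=(new_grid, new_grid),
                             mode="bicubic", align_corners=False)
    grid_tok = grid_tok.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, -1)
    return torch.cat([cls_tok, grid_tok], dim=1)

old = torch.randn(1, 730, 384)          # 27x27 patches + 1 class token
print(resize_pos_embed(old, 14).shape)  # torch.Size([1, 197, 384])
```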

No grad accumulator for a saved leaf

I get an error: "../torch/csrc/autograd/saved_variable.cpp":216, please report a bug to PyTorch. No grad accumulator for a saved leaf.

It is thrown in mamba_ssm/ops/selective_scan_interface.py at backward(ctx, dout):

(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, 
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors


Could you please give me some hints? Thank you.

Incompatible to CU118 when installing causal-conv1d and mamba

Thanks a lot to the authors for the wonderful work. I have some quick questions about the repo.

When installing causal-conv1d and mamba, I encountered an error saying "The detected CUDA version (12.3) mismatches the version that was used to compile PyTorch (11.8). Please make sure to use the same CUDA versions.". Do you know how to address it?

I also tried several different torch and cuda versions (e.g., torch 2.1.1 + cu121); sometimes the install succeeds but model training becomes quite slow. Did you observe similar issues, or does this repo have a strict constraint on the torch and cuda versions?
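For reference, the extensions are compiled with the system nvcc, which must match the CUDA version PyTorch was built against; a quick check:

```python
import torch

print(torch.__version__)          # e.g. 2.1.1+cu118
print(torch.version.cuda)         # CUDA toolkit PyTorch was compiled against
print(torch.cuda.is_available())
# compare against the system compiler used to build the extensions:
#   $ nvcc --version
```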

Hidden State understanding

Hi @xinggangw @Unrealluver @ifzhang @LegendBC ,
Thanks for releasing this great work. I am trying to interpret the model so that its results are more explainable. Using the pretrained image classification weights, how can I record, during training or inference, the attention-like weights generated within the BSSM at each time step? I have gone through mamba_model.py, where we have hidden states and residuals. Can I use those directly to understand the model predictions, or do I need to extract other details via main.py or mamba_model.py? If so, what directions could we look into, such as saliency maps? Any suggestions would be highly useful. Thanks in advance.
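One generic way to record per-block hidden states is a forward hook on each block; below is a self-contained sketch with a stand-in model (for Vim, iterate the model's block list, e.g. model.layers, instead):

```python
import torch
import torch.nn as nn

# stand-in model: any nn.Module with a list of blocks works the same way
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
images = torch.randn(4, 8)

records = {}

def save_output(name):
    def hook(module, inputs, output):
        # block outputs may be tensors or tuples (e.g. hidden_states, residual)
        out = output[0] if isinstance(output, tuple) else output
        records[name] = out.detach().cpu()
    return hook

for i, layer in enumerate(model):  # for Vim: enumerate(model.layers)
    layer.register_forward_hook(save_output(f"layer_{i}"))

with torch.no_grad():
    model(images)
print(list(records))  # per-layer hidden states, usable for saliency-style analysis
```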

License

Hi, thanks for releasing! Any plans to add a license? Thx!!

zope-interface>=5

When I run "python setup.py install" in the mamba-1p1p1 folder, I get the following error:
Processing dependencies for mamba-ssm==1.1.1
Searching for zope-interface>=5
Reading https://pypi.org/simple/zope-interface/
No local packages or working download links found for zope-interface>=5
error: Could not find suitable distribution for Requirement.parse('zope-interface>=5')

However, I have installed zope-interface==5 and also zope-interface==6.2, and it still gives the same error.
What is the reason for this? How can I fix it?

Upload code for segmentor head

Hi,

Could you please upload the code for the segmentation head? I don't see it in the codebase, while the paper clearly reports segmentation performance.

Thanks,
Michael

Cannot use pretrained model

Hello,

I am running this command on an Ubuntu 22.04 LTS system: python main.py --finetune '' --eval --resume vim_t_midclstok_76p1acc.pth --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --data-path imagenet-mini --data-set 'IMNET', in order to try your model as suggested in the documentation.

I installed mamba-ssm using pip wheels because your recommended approach does not work, as noted in another issue.

However, it seems that the pretrained model is not loaded correctly as shown in the error message below. What could I do to fix this?

Traceback (most recent call last):
File "/rds/general/user/kp4718/home/code/Vim/vim/main.py", line 545, in
main(args)
File "/rds/general/user/kp4718/home/code/Vim/vim/main.py", line 448, in main
model_without_ddp.load_state_dict(checkpoint['model'])
File "/rds/general/user/kp4718/home/anaconda3/envs/mambaenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VisionMamba:
Unexpected key(s) in state_dict: "layers.0.mixer.A_b_log", "layers.0.mixer.D_b", "layers.0.mixer.conv1d_b.weight", "layers.0.mixer.conv1d_b.bias", "layers.0.mixer.x_proj_b.weight", "layers.0.mixer.dt_proj_b.weight", "layers.0.mixer.dt_proj_b.bias", and the same seven *_b keys for each of layers 1 through 23.

Ran the test code successfully; sharing my environment

Hello every researcher,
I have been working with this code for 3 days and can finally run the test code successfully. Below is the test code.
```python
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
print(torch.cuda.is_available())
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```

Therefore, I would like to share my setup to help those who are still struggling with environment problems.
My system:

Ubuntu 22.04
NVIDIA driver 535 (if you don't know how to upgrade or downgrade your NVIDIA driver, you can leave a message here)
CUDA 11.8 (run file downloaded from the official website)
Python 3.10.13 (installed via conda)
PyTorch 2.2.2 (installed via conda)
causal-conv1d 1.1.3.post1 (installed via pip install causal-conv1d==1.1.3.post1)
mamba_ssm (installed from the author-provided source: cd mamba-1p1p1 && pip install -e .)

Issue with selective_scan_cuda

Hi,

I cannot properly import selective_scan_cuda.
How can the error about the undefined symbol below be solved?

Traceback (most recent call last):
File "/rds/general/user/kp4718/home/code/Vim/vim/main.py", line 28, in
import models_mamba
File "/rds/general/user/kp4718/home/code/Vim/vim/models_mamba.py", line 21, in
from mamba_ssm.modules.mamba_simple import Mamba
File "/rds/general/user/kp4718/home/anaconda3/envs/mambaenv/lib/python3.10/site-packages/mamba_ssm/init.py", line 3, in
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
File "/rds/general/user/kp4718/home/anaconda3/envs/mambaenv/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 11, in
import selective_scan_cuda
ImportError: /rds/general/user/kp4718/home/anaconda3/envs/mambaenv/lib/python3.10/site-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEb

Memory/speed improvements over DeiT for larger Vim

I find your paper on Vision Mamba very interesting. However, when using your code, I encountered something that may well be normal behavior. When analyzing GPU memory consumption and FPS for Vim variants other than Tiny, I could not reproduce similar speed and memory improvements. I compared against DeiT, and the improvements were only visible for Vim-Ti. Am I doing something wrong, or do the improvements only hold for the Tiny version?
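Note that the paper's headline efficiency numbers are measured at 1248x1248; at small resolutions and batch sizes the gap can be much smaller. For anyone reproducing such comparisons, a minimal measurement sketch (a hypothetical harness, not the authors' benchmarking code):

```python
import time
import torch

@torch.no_grad()
def bench(model, batch=16, res=224, iters=50, device="cuda"):
    """Sketch of an FPS / peak-memory measurement for one backbone."""
    model = model.eval().to(device)
    x = torch.randn(batch, 3, res, res, device=device)
    for _ in range(10):  # warmup, so compilation/caching does not skew timing
        model(x)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device)
    t0 = time.time()
    for _ in range(iters):
        model(x)
    torch.cuda.synchronize()
    fps = batch * iters / (time.time() - t0)
    mem_gib = torch.cuda.max_memory_allocated(device) / 2**30
    print(f"{fps:.1f} img/s, peak {mem_gib:.2f} GiB")
```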

Training Recipe for Vim-B (base model)

Hello,

Thank you very much for the insightful work detailed in your paper!

The training recipe (e.g. learning rate, scheduler, #of epochs) and results for the Vim-T and Vim-S models were quite informative.

I am actually interested in adapting this approach for the base model. Could you provide guidance on any modifications needed in the training recipe for this purpose?

Additionally, if you might have results for the base model on ImageNet 1K, I would greatly appreciate if you could share them. This would allow me to benchmark my training outcomes against yours for a comprehensive comparison.

Thank you once again for your valuable contributions!

bimamba version

While reading your code, I noticed that bimamba has v1 and v2 versions, corresponding to BimambaInnerFn and MambaInnerFnNoOutProj respectively. v1 runs a forward scan and a reversed scan over the same sequence but uses the same proj module to compute B and C, whereas v2 defines two proj modules (x_proj and x_proj_b) to compute separate B and C for each direction. Why was it designed this way, and which of the two works better? Finally, is the step() function in mamba never executed during either training or inference, since step() does not specify the two-direction computation? Thanks a lot. 😦

VIM-S

Any plans for uploading the weights for the vim-s model?

Code to train Vim object detection

Hi everyone, first things first: kudos to the whole research team.

Do you plan to release the object detection training code anytime soon? The memory consumption of this architecture is very compelling.

Question about the design of gate Z;

Hello,
I would like to express my appreciation for the outstanding work on this project.

In the original Mamba, there is no existence of the 'z' mechanism as a gate. However, in Vim, an additional 'z' has been incorporated as a gate. Why was it designed this way? What would happen if it were removed?

The code in question can be found at:

y = y * self.act(z) # (B D)

Your clarification would be immensely appreciated.
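For context, here is a toy, self-contained illustration of where the gate sits: the SSM-path output is elementwise modulated by a SiLU of the parallel z branch. This is a sketch of the computation in question, not the fused kernel:

```python
import torch
import torch.nn.functional as F

def gated_output(y, z):
    """Toy: the SSM-path output y is modulated by the parallel gate branch z."""
    return y * F.silu(z)  # removing the gate would reduce this to plain y

y = torch.randn(2, 16, 64)  # (batch, dim, seq) output of the selective scan
z = torch.randn(2, 16, 64)  # gate branch from the input projection (xz split)
out = gated_output(y, z)
```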

problem about pip install -e causal_conv1d>=1.1.0

I try to install causal_conv1d and use the command:
pip install -e causal_conv1d>=1.1.0
but have this error:
ERROR: causal_conv1d is not a valid editable requirement. It should either be a path to a local project or a VCS URL (beginning with bzr+http, bzr+https, bzr+ssh, bzr+sftp, bzr+ftp, bzr+lp, bzr+file, git+http, git+https, git+ssh, git+git, git+file, hg+file, hg+http, hg+https, hg+ssh, hg+static-http, svn+ssh, svn+http, svn+https, svn+svn, svn+file).

then I use the command :
pip install -e causal-conv1d

but this has another problem:
Installing collected packages: causal-conv1d
Running setup.py develop for causal-conv1d
error: subprocess-exited-with-error

× python setup.py develop did not run successfully.
│ exit code: 1
╰─> [83 lines of output]
    
    
    torch.__version__  = 2.1.1+cu118
    
    
    running develop
    /home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/command/develop.py:40: EasyInstallDeprecationWarning: easy_install command is deprecated.
    !!
    
            ********************************************************************************
            Please avoid running ``setup.py`` and ``easy_install``.
            Instead, use pypa/build, pypa/installer or other
            standards-based tools.
    
            See https://github.com/pypa/setuptools/issues/917 for details.
            ********************************************************************************
    
    !!
      easy_install.initialize_options(self)
    /home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
    !!
    
            ********************************************************************************
            Please avoid running ``setup.py`` directly.
            Instead, use pypa/build, pypa/installer or other
            standards-based tools.
    
            See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
            ********************************************************************************
    
    !!
      self.initialize_options()
    running egg_info
    creating causal_conv1d.egg-info
    writing causal_conv1d.egg-info/PKG-INFO
    writing dependency_links to causal_conv1d.egg-info/dependency_links.txt
    writing requirements to causal_conv1d.egg-info/requires.txt
    writing top-level names to causal_conv1d.egg-info/top_level.txt
    writing manifest file 'causal_conv1d.egg-info/SOURCES.txt'
    reading manifest file 'causal_conv1d.egg-info/SOURCES.txt'
    adding license file 'LICENSE'
    adding license file 'AUTHORS'
    writing manifest file 'causal_conv1d.egg-info/SOURCES.txt'
    running build_ext
    Traceback (most recent call last):
      File "<string>", line 2, in <module>
      File "<pip-setuptools-caller>", line 34, in <module>
      File "/home/wzw/Vim/causal-conv1d/setup.py", line 226, in <module>
        setup(
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/__init__.py", line 103, in setup
        return distutils.core.setup(**attrs)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 185, in setup
        return run_commands(dist)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
        dist.run_commands()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 969, in run_commands
        self.run_command(cmd)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
        super().run_command(command)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
        cmd_obj.run()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/command/develop.py", line 34, in run
        self.install_for_development()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/command/develop.py", line 109, in install_for_development
        self.run_command('build_ext')
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
        self.distribution.run_command(command)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
        super().run_command(command)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
        cmd_obj.run()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 88, in run
        _build_ext.run(self)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 345, in run
        self.build_extensions()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 525, in build_extensions
        _check_cuda_version(compiler_name, compiler_version)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 413, in _check_cuda_version
        raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda))
    RuntimeError:
    The detected CUDA version (12.2) mismatches the version that was used to compile
    PyTorch (11.8). Please make sure to use the same CUDA versions.
    
    [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.

error: subprocess-exited-with-error

Much lower training efficiency

Thanks for your great work! However, I observe that the training efficiency (both training speed and memory use) is much lower than that of a plain ViT with a similar model size. Do you have any insights into this phenomenon?

subprocess.CalledProcessError

/home/ai1015/anaconda3/envs/vim/bin/python /mnt/data/ai1015/Vim/vim/main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2
Not using distributed mode
Namespace(gpu=0, batch_size=2, epochs=300, bce_loss=False, unscale_lr=False, model='vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, sched='cosine', lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-06, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, cosub=False, finetune='', attn_only=False, data_path='/mnt/data/ai1015/Vim/data/imagenet/', data_set='IMNET', inat_category='name', output_dir='', device='cuda', seed=0, resume='', start_epoch=0, eval=False, eval_crop_ratio=0.875, dist_eval=False, num_workers=10, pin_mem=True, distributed=False, world_size=1, dist_url='env://', if_amp=True, if_continue_inf=False, if_nan2num=False, if_random_cls_token_position=False, if_random_token_rank=False, local_rank=0)
Creating model: vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2
number of params: 6228874
Start training for 300 epochs
/usr/bin/ld: skipping incompatible /lib/i386-linux-gnu/libcuda.so when searching for -lcuda
/usr/bin/ld: skipping incompatible /lib/i386-linux-gnu/libcuda.so when searching for -lcuda
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
Traceback (most recent call last):
File "/mnt/data/ai1015/Vim/vim/main.py", line 550, in
main(args)
File "/mnt/data/ai1015/Vim/vim/main.py", line 482, in main
train_stats = train_one_epoch(
File "/mnt/data/ai1015/Vim/vim/engine.py", line 54, in train_one_epoch
outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/data/ai1015/Vim/vim/models_mamba.py", line 543, in forward
x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
File "/mnt/data/ai1015/Vim/vim/models_mamba.py", line 480, in forward_features
hidden_states, residual = layer(
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/data/ai1015/Vim/vim/models_mamba.py", line 115, in forward
hidden_states, residual = fused_add_norm_fn(
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in
timings = {config: self._bench(*args, config=config, **kwargs)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 83, in _bench
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench
fn()
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 81, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "", line 63, in _layer_norm_fwd_1pass_kernel
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/compiler/compiler.py", line 425, in compile
so_path = make_stub(name, signature, constants)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/compiler/make_launcher.py", line 39, in make_stub
so = _build(name, src_path, tmpdir)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/common/build.py", line 90, in _build
ret = subprocess.check_call(cc_cmd)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/subprocess.py", line 369, in check_call
raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmpci9keay7/main.c', '-O3', '-I/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/common/../third_party/cuda/include', '-I/home/ai1015/anaconda3/envs/vim/include/python3.10', '-I/tmp/tmpci9keay7', '-shared', '-fPIC', '-lcuda', '-o', '/tmp/tmpci9keay7/_layer_norm_fwd_1pass_kernel.cpython-310-x86_64-linux-gnu.so', '-L/lib/x86_64-linux-gnu', '-L/lib/i386-linux-gnu', '-L/lib/i386-linux-gnu']' returned non-zero exit status 1.

What is the problem?

Activation operation of Vim Block

Thanks for the great work!
According to the paper, there should be a SiLU activation operation in the Vim block. However, when I check the following code in mamba-1p1p1/mamba_ssm/modules/mamba_simple.py, I cannot find the activation operation.

            elif self.bimamba_type == "v2":
                A_b = -torch.exp(self.A_b_log.float())
                out = mamba_inner_fn_no_out_proj(
                    xz,
                    self.conv1d.weight,
                    self.conv1d.bias,
                    self.x_proj.weight,
                    self.dt_proj.weight,
                    A,
                    None,  # input-dependent B
                    None,  # input-dependent C
                    self.D.float(),
                    delta_bias=self.dt_proj.bias.float(),
                    delta_softplus=True,
                )
                out_b = mamba_inner_fn_no_out_proj(
                    xz.flip([-1]),
                    self.conv1d_b.weight,
                    self.conv1d_b.bias,
                    self.x_proj_b.weight,
                    self.dt_proj_b.weight,
                    A_b,
                    None,
                    None,
                    self.D_b.float(),
                    delta_bias=self.dt_proj_b.bias.float(),
                    delta_softplus=True,
                )
                # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
                if not self.if_devide_out:
                    out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
                else:
                    out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2, self.out_proj.weight, self.out_proj.bias)

I further checked the mamba_inner_fn_no_out_proj method in the MambaInnerFnNoOutProj class in mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py and the CausalConv1dFn class in causal-conv1d/causal_conv1d/causal_conv1d_interface.py. Although an 'activation' argument occurs in CausalConv1dFn, it seems to receive a 'None' value.

class MambaInnerFnNoOutProj(torch.autograd.Function):
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        ...
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
        ...

class CausalConv1dFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias=None, activation=None):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        ctx.save_for_backward(x, weight, bias)
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
        return out
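Reading the call above, the trailing True passed to causal_conv1d_fwd appears to be the activation flag (with None in the seq_idx position of the newer interface), so the SiLU seems to be applied inside the fused conv kernel rather than via an explicit self.act; the z-gate supplies the other SiLU. A toy unfused equivalent of the causal conv + activation step (a pure-PyTorch sketch, not the CUDA kernel):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, dim, seq, width = 2, 16, 64, 4
x = torch.randn(batch, dim, seq)
# depthwise conv with left padding, truncated to keep causality
conv = nn.Conv1d(dim, dim, width, padding=width - 1, groups=dim)
h = F.silu(conv(x)[..., :seq])  # the SiLU in question, applied after the conv
print(h.shape)  # torch.Size([2, 16, 64])
```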

pretraining.sh launch script is hardcoded to dev folders/env...can you provide a generic launch script for regular servers?

Can you provide a generic launch script for pre-training with vim?
The readme suggests to run:
bash vim/scripts/pt-vim-t.sh

but the contents of pt-vim-t.sh contain a number of hardcoded local paths that do not exist on regular servers, e.g.:
--data-path /share/project/lianghuizhu/datasets/IN1K
source /opt/conda/bin/activate /home/zhulianghui/.conda/envs/deit-dev
etc.

Thanks very much.

pip install -e mamba

I got the error namespace "cub" has no member "WarpMask" when installing mamba, but I have no idea how to solve it. Could you give me some advice? Thanks

LR & Optim setting for downstream task

Dear authors,

Thanks for your excellent work.
I would like to use Vim as the backbone for a downstream task: I need to load the weights pretrained on ImageNet-1K and then fine-tune the network on the downstream-task dataset.
What lr_scheduler and optimizer should I use? Could you please give me some suggestions about the hyper-parameters?
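Not an official recipe, but a common fine-tuning starting point for an ImageNet-pretrained backbone is AdamW with a cosine schedule and a lower learning rate than pretraining; a minimal sketch with a stand-in model:

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the Vim backbone + task head
num_epochs = 30                # assumption: a short downstream schedule

# a common fine-tuning starting point, not the authors' published recipe
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4,
                              weight_decay=0.05, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
```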

TypeError: causal_conv1d_fwd(): incompatible function arguments.

Hi, thanks for the great work!
I encountered the following error: TypeError: causal_conv1d_fwd(): incompatible function arguments.
I'm wondering if conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) should be conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) in mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py?
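The cleaner fix is installing the causal-conv1d version that the bundled mamba-1p1p1 expects, but as a diagnostic one can try both calling conventions seen above; a defensive sketch (a hypothetical wrapper, not part of the repo):

```python
def causal_conv1d_fwd_compat(x, weight, bias):
    """Sketch: try both causal_conv1d_fwd calling conventions seen in the wild."""
    import causal_conv1d_cuda
    try:
        # 5-argument form: (x, weight, bias, seq_idx, activation)
        return causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, None, True)
    except TypeError:
        # 4-argument form: (x, weight, bias, activation)
        return causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, True)
```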

Mamba module initialization

Dear authors,

Thanks for the great work and releasing the codebase! I have one question regarding the Mamba module initialization.

From the code, I see its linear layers are initialized here and here, and then the weights and biases are overwritten here and here, which zeros out the Mamba linear layers' biases. So could that initialization line actually be deleted? Please share your insights if it is intended. Thanks!

Difference between with/without fused_add_norm

Hi @Unrealluver,

I saw you fused the add and norm operations in the Block class. I'm unsure of the difference between fused_add_norm=True and fused_add_norm=False. More specifically, can I simply treat fused_add_norm_fn as an integration of the following code?

Vim/vim/models_mamba.py

Lines 75 to 83 in 06c5009

```python
if not self.fused_add_norm:
    if residual is None:
        residual = hidden_states
    else:
        residual = residual + self.drop_path(hidden_states)
    hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
    if self.residual_in_fp32:
        residual = residual.to(torch.float32)
```

By the way, have you tried to use LN -> Mixer -> Add like a standard block does? Will it be different compared with Add -> LN -> Mixer in accuracy or speed?
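As far as one can tell from the snippet above, yes, modulo drop_path and the fp32 residual handling: the fused function computes the same prenorm add-then-norm in a single Triton kernel. A minimal unfused reference sketch of that computation (not the actual kernel):

```python
import torch

def add_norm_reference(hidden_states, residual, norm, residual_in_fp32=True):
    """Unfused reference for the prenorm add + norm that the fused kernel computes."""
    residual = hidden_states if residual is None else residual + hidden_states
    if residual_in_fp32:
        residual = residual.to(torch.float32)
    # normalize in the norm weight's dtype, keep the residual in fp32
    hidden_states = norm(residual.to(dtype=norm.weight.dtype))
    return hidden_states, residual
```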
