
vim's Introduction

vim's People

Contributors

jingfengyao, legendbc, unrealluver, xinggangw


vim's Issues

Issue with selective_scan_cuda

Hi,

I cannot properly import selective_scan_cuda:
How can the error about the unresolved symbol be solved?

Traceback (most recent call last):
File "/rds/general/user/kp4718/home/code/Vim/vim/main.py", line 28, in
import models_mamba
File "/rds/general/user/kp4718/home/code/Vim/vim/models_mamba.py", line 21, in
from mamba_ssm.modules.mamba_simple import Mamba
File "/rds/general/user/kp4718/home/anaconda3/envs/mambaenv/lib/python3.10/site-packages/mamba_ssm/init.py", line 3, in
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
File "/rds/general/user/kp4718/home/anaconda3/envs/mambaenv/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 11, in
import selective_scan_cuda
ImportError: /rds/general/user/kp4718/home/anaconda3/envs/mambaenv/lib/python3.10/site-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEb
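
A diagnostic sketch that may help narrow this down (my assumption, not the authors' fix: an undefined c10::Warning symbol usually means selective_scan_cuda was compiled against a different PyTorch than the one currently installed, so comparing versions before rebuilding mamba_ssm is a reasonable first step):

    # Sketch only: print the installed PyTorch and mamba-ssm versions without
    # importing the broken extension itself.
    import torch
    from importlib.metadata import version

    print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)
    print("mamba-ssm:", version("mamba-ssm"))  # rebuild against this torch if the two drifted apart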

VIM-S

Any plans for uploading the weights for the vim-s model?

TypeError: causal_conv1d_fwd(): incompatible function arguments.

Hi, thanks for the great work!
I encountered the error shown in the title (screenshot omitted).
I'm wondering whether conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) should be conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, True) in mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py?
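
A hedged sketch of what I mean (my assumption: newer causal-conv1d builds expect an extra seq_idx argument before the activation flag, while older builds take only four arguments, so the working call depends on the installed version):

    # Sketch only: pick whichever causal_conv1d_fwd signature the installed
    # causal-conv1d extension actually exposes.
    import causal_conv1d_cuda

    def causal_conv1d_fwd_compat(x, weight, bias, activation=True):
        try:
            # assumed newer signature: (x, weight, bias, seq_idx, activation)
            return causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, None, activation)
        except TypeError:
            # assumed older signature: (x, weight, bias, activation)
            return causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, activation)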

LR & Optim setting for downstream task

Dear authors,

Thanks for your excellent work.
I would like to use Vim as the backbone for a downstream task. I need to load the ImageNet-1K pretrained weights and then fine-tune the network on the downstream dataset.
Which lr_scheduler and optimizer should I use?
Could you please give me some suggestions about the hyper-parameters?
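
For reference, this is the kind of setup I currently have in mind (a common fine-tuning baseline, not the authors' recipe; the model stand-in, learning rate, and epoch counts are placeholders):

    import torch
    import torch.nn as nn

    model = nn.Linear(384, 100)  # placeholder for a Vim backbone with a new task head

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.05)
    warmup_epochs, total_epochs = 5, 100
    warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs, eta_min=1e-6)
    scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

    for epoch in range(total_epochs):
        # ... one epoch of training on the downstream dataset ...
        scheduler.step()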

Upload code for segmentor head

Hi,

Could you please upload the code for the segmentation head? I don't see it in the codebase, while the paper clearly reports segmentation performance.

Thanks,
Michael

Difference between with/without fused_add_norm

Hi @Unrealluver,

I saw you fused the add and norm operations in the Block class. I'm unsure of the difference between fused_add_norm=True and fused_add_norm=False. More specifically, can I simply treat fused_add_norm_fn as an integration of the following code?

Vim/vim/models_mamba.py

Lines 75 to 83 in 06c5009

if not self.fused_add_norm:
    if residual is None:
        residual = hidden_states
    else:
        residual = residual + self.drop_path(hidden_states)
    hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
    if self.residual_in_fp32:
        residual = residual.to(torch.float32)
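
As I understand it, the fused path computes the same thing in a single Triton kernel; a minimal sketch of my reading, assuming mamba_ssm's rms_norm_fn signature is rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):

    # Sketch only: the unfused reference vs. what I believe the fused call does.
    import torch
    from mamba_ssm.ops.triton.layernorm import rms_norm_fn

    def add_norm_unfused(hidden_states, residual, norm, drop_path, residual_in_fp32=False):
        # reference path from the snippet above: residual add first, then normalization
        residual = hidden_states if residual is None else residual + drop_path(hidden_states)
        hidden_states = norm(residual.to(dtype=norm.weight.dtype))
        if residual_in_fp32:
            residual = residual.to(torch.float32)
        return hidden_states, residual

    def add_norm_fused(hidden_states, residual, norm, drop_path, residual_in_fp32=False):
        # fused path: one Triton kernel does the add and the (RMS)norm and returns
        # both the normalized activations and the updated residual (prenorm=True)
        return rms_norm_fn(
            drop_path(hidden_states),
            norm.weight,
            getattr(norm, "bias", None),
            residual=residual,
            prenorm=True,
            residual_in_fp32=residual_in_fp32,
            eps=norm.eps,
        )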

By the way, have you tried using LN -> Mixer -> Add like a standard block does? Would it differ from Add -> LN -> Mixer in accuracy or speed?

No grad accumulator for a saved leaf

I get the following error: ../torch/csrc/autograd/saved_variable.cpp:216, please report a bug to PyTorch. No grad accumulator for a saved leaf.

It's thrown in mamba_ssm/ops/selective_scan_interface.py, in backward(ctx, dout):

(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, 
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors


Could you please give me some hints? Thank you.

subprocess.CalledProcessError

/home/ai1015/anaconda3/envs/vim/bin/python /mnt/data/ai1015/Vim/vim/main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2
Not using distributed mode
Namespace(gpu=0, batch_size=2, epochs=300, bce_loss=False, unscale_lr=False, model='vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, sched='cosine', lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-06, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillation_tau=1.0, cosub=False, finetune='', attn_only=False, data_path='/mnt/data/ai1015/Vim/data/imagenet/', data_set='IMNET', inat_category='name', output_dir='', device='cuda', seed=0, resume='', start_epoch=0, eval=False, eval_crop_ratio=0.875, dist_eval=False, num_workers=10, pin_mem=True, distributed=False, world_size=1, dist_url='env://', if_amp=True, if_continue_inf=False, if_nan2num=False, if_random_cls_token_position=False, if_random_token_rank=False, local_rank=0)
Creating model: vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2
number of params: 6228874
Start training for 300 epochs
/usr/bin/ld: skipping incompatible /lib/i386-linux-gnu/libcuda.so when searching for -lcuda
/usr/bin/ld: skipping incompatible /lib/i386-linux-gnu/libcuda.so when searching for -lcuda
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
Traceback (most recent call last):
File "/mnt/data/ai1015/Vim/vim/main.py", line 550, in
main(args)
File "/mnt/data/ai1015/Vim/vim/main.py", line 482, in main
train_stats = train_one_epoch(
File "/mnt/data/ai1015/Vim/vim/engine.py", line 54, in train_one_epoch
outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/data/ai1015/Vim/vim/models_mamba.py", line 543, in forward
x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
File "/mnt/data/ai1015/Vim/vim/models_mamba.py", line 480, in forward_features
hidden_states, residual = layer(
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/data/ai1015/Vim/vim/models_mamba.py", line 115, in forward
hidden_states, residual = fused_add_norm_fn(
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in
timings = {config: self._bench(*args, config=config, **kwargs)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 83, in _bench
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench
fn()
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 81, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "", line 63, in _layer_norm_fwd_1pass_kernel
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/compiler/compiler.py", line 425, in compile
so_path = make_stub(name, signature, constants)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/compiler/make_launcher.py", line 39, in make_stub
so = _build(name, src_path, tmpdir)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/common/build.py", line 90, in _build
ret = subprocess.check_call(cc_cmd)
File "/home/ai1015/anaconda3/envs/vim/lib/python3.10/subprocess.py", line 369, in check_call
raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmpci9keay7/main.c', '-O3', '-I/home/ai1015/anaconda3/envs/vim/lib/python3.10/site-packages/triton/common/../third_party/cuda/include', '-I/home/ai1015/anaconda3/envs/vim/include/python3.10', '-I/tmp/tmpci9keay7', '-shared', '-fPIC', '-lcuda', '-o', '/tmp/tmpci9keay7/_layer_norm_fwd_1pass_kernel.cpython-310-x86_64-linux-gnu.so', '-L/lib/x86_64-linux-gnu', '-L/lib/i386-linux-gnu', '-L/lib/i386-linux-gnu']' returned non-zero exit status 1.

What is the problem?
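
A small diagnostic sketch I would try first (an assumption on my part, not a fix): check whether a 64-bit CUDA driver library is visible at all, since the linker above only finds the 32-bit /lib/i386-linux-gnu/libcuda.so:

    # Sketch only: confirm a loadable 64-bit libcuda exists before debugging
    # Triton's stub build, which links with -lcuda.
    import ctypes
    import ctypes.util

    print(ctypes.util.find_library("cuda"))  # None means the loader cannot resolve libcuda
    ctypes.CDLL("libcuda.so.1")              # raises OSError if no usable driver library is found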

How to test pretrained model on CIFAR dataset and coco2017

Hi,

Since the pretrained model's head has 1000 classes, how can I test it on the CIFAR dataset, which only has 100 classes?
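
In case it helps to clarify what I mean, this is the kind of adaptation I had in mind (a minimal sketch with assumed checkpoint key names, not an official procedure):

    # Sketch only: drop the 1000-class classifier from the checkpoint, load the
    # rest, and attach a fresh 100-class head for fine-tuning on CIFAR-100.
    import torch
    import torch.nn as nn

    def adapt_head_for_cifar100(model: nn.Module, checkpoint_path: str, embed_dim: int = 192) -> nn.Module:
        state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]  # assumed checkpoint layout
        for k in ("head.weight", "head.bias"):   # assumed classifier key names
            state_dict.pop(k, None)
        model.load_state_dict(state_dict, strict=False)
        model.head = nn.Linear(embed_dim, 100)   # new CIFAR-100 head, trained during fine-tuning
        return model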

For the COCO 2017 dataset, could you provide the data preprocessing code showing how to organize it properly for evaluation?
Here is my data folder structure for the COCO dataset:
-dataset
  -coco
    -annotations
    -images
      -test
      -train

zope-interface>=5

When I run "python setup.py install" in the mamba-1p1p1 folder, I get the following error:
Processing dependencies for mamba-ssm==1.1.1
Searching for zope-interface>=5
Reading https://pypi.org/simple/zope-interface/
No local packages or working download links found for zope-interface>=5
error: Could not find suitable distribution for Requirement.parse('zope-interface>=5')

However, I have already installed zope-interface==5 (and also tried zope-interface==6.2), and it still gives the same error.
What is the reason for this? How can I fix it?

pretraining.sh launch script is hardcoded to dev folders/env...can you provide a generic launch script for regular servers?

Can you provide a generic launch script for pre-training with vim?
The readme suggests running:
bash vim/scripts/pt-vim-t.sh

but the contents of pt-vim-t.sh contain a bunch of hardcoded local paths that do not exist on regular servers, such as:
--data-path /share/project/lianghuizhu/datasets/IN1K
source /opt/conda/bin/activate /home/zhulianghui/.conda/envs/deit-dev
etc.

Thanks very much.

Training Recipe for Vim-B (base model)

Hello,

Thank you very much for the insightful work detailed in your paper!

The training recipe (e.g., learning rate, scheduler, number of epochs) and results for the Vim-T and Vim-S models were quite informative.

I am actually interested in adapting this approach for the base model. Could you provide guidance on any modifications needed in the training recipe for this purpose?

Additionally, if you have results for the base model on ImageNet-1K, I would greatly appreciate it if you could share them. This would allow me to benchmark my training outcomes against yours for a comprehensive comparison.

Thank you once again for your valuable contributions!

problem about pip install -e causal_conv1d>=1.1.0

I tried to install causal_conv1d using the command:
pip install -e causal_conv1d>=1.1.0
but got this error:
ERROR: causal_conv1d is not a valid editable requirement. It should either be a path to a local project or a VCS URL (beginning with bzr+http, bzr+https, bzr+ssh, bzr+sftp, bzr+ftp, bzr+lp, bzr+file, git+http, git+https, git+ssh, git+git, git+file, hg+file, hg+http, hg+https, hg+ssh, hg+static-http, svn+ssh, svn+http, svn+https, svn+svn, svn+file).

Then I used the command:
pip install -e causal-conv1d

but this ran into another problem:
Installing collected packages: causal-conv1d
Running setup.py develop for causal-conv1d
error: subprocess-exited-with-error

× python setup.py develop did not run successfully.
│ exit code: 1
╰─> [83 lines of output]
    
    
    torch.__version__  = 2.1.1+cu118
    
    
    running develop
    /home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/command/develop.py:40: EasyInstallDeprecationWarning: easy_install command is deprecated.
    !!
    
            ********************************************************************************
            Please avoid running ``setup.py`` and ``easy_install``.
            Instead, use pypa/build, pypa/installer or other
            standards-based tools.
    
            See https://github.com/pypa/setuptools/issues/917 for details.
            ********************************************************************************
    
    !!
      easy_install.initialize_options(self)
    /home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
    !!
    
            ********************************************************************************
            Please avoid running ``setup.py`` directly.
            Instead, use pypa/build, pypa/installer or other
            standards-based tools.
    
            See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
            ********************************************************************************
    
    !!
      self.initialize_options()
    running egg_info
    creating causal_conv1d.egg-info
    writing causal_conv1d.egg-info/PKG-INFO
    writing dependency_links to causal_conv1d.egg-info/dependency_links.txt
    writing requirements to causal_conv1d.egg-info/requires.txt
    writing top-level names to causal_conv1d.egg-info/top_level.txt
    writing manifest file 'causal_conv1d.egg-info/SOURCES.txt'
    reading manifest file 'causal_conv1d.egg-info/SOURCES.txt'
    adding license file 'LICENSE'
    adding license file 'AUTHORS'
    writing manifest file 'causal_conv1d.egg-info/SOURCES.txt'
    running build_ext
    Traceback (most recent call last):
      File "<string>", line 2, in <module>
      File "<pip-setuptools-caller>", line 34, in <module>
      File "/home/wzw/Vim/causal-conv1d/setup.py", line 226, in <module>
        setup(
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/__init__.py", line 103, in setup
        return distutils.core.setup(**attrs)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 185, in setup
        return run_commands(dist)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
        dist.run_commands()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 969, in run_commands
        self.run_command(cmd)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
        super().run_command(command)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
        cmd_obj.run()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/command/develop.py", line 34, in run
        self.install_for_development()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/command/develop.py", line 109, in install_for_development
        self.run_command('build_ext')
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
        self.distribution.run_command(command)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
        super().run_command(command)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
        cmd_obj.run()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 88, in run
        _build_ext.run(self)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 345, in run
        self.build_extensions()
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 525, in build_extensions
        _check_cuda_version(compiler_name, compiler_version)
      File "/home/wzw/miniconda3/envs/vim/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 413, in _check_cuda_version
        raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda))
    RuntimeError:
    The detected CUDA version (12.2) mismatches the version that was used to compile
    PyTorch (11.8). Please make sure to use the same CUDA versions.
    
    [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.

error: subprocess-exited-with-error
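
A quick check sketch (my assumption, not a fix) to confirm which CUDA toolkit the build picks up versus the one PyTorch was compiled with, before installing a matching toolkit or a matching PyTorch wheel:

    # Sketch only: compare the CUDA version baked into the PyTorch wheel with the
    # nvcc found on PATH, which is what the extension build uses.
    import subprocess
    import torch

    print("PyTorch built with CUDA:", torch.version.cuda)
    print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout)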

Code to train Vim object detection

Hi everyone, first things first, kudos to the whole research team.

Do you plan to release the object detection training code anytime soon? The memory consumption of this architecture is very compelling.

Issue with Loading Models from Huggingface Repo

Hi,
Thank you for your fantastic project.

I've encountered a problem when trying to evaluate the models downloaded from the Huggingface repo linked in your project. I'm facing the issue with both the small and tiny Vim models. When I load the models, I receive the following error message:
"size mismatch for pos_embed: copying a param with shape torch.Size([1, 730, 384]) from checkpoint, the shape in the current model is torch.Size([1, 197, 384])."

Could you please offer some guidance on how to address this problem? Your assistance would be greatly appreciated.

Thank you so much for your attention to this matter!
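
If the mismatch is purely a resolution difference (730 would correspond to a 27x27 patch grid plus one extra token, i.e. a larger input than 224), would interpolating the position embedding be the right workaround? A minimal sketch of what I mean, assuming the extra token(s) sit at the front of pos_embed (which may not hold for the mid-class-token models):

    # Sketch only: bicubically resize the patch-grid part of pos_embed to a new
    # grid size before loading the state_dict; all names here are assumptions.
    import torch
    import torch.nn.functional as F

    def resize_pos_embed(pos_embed: torch.Tensor, num_extra_tokens: int, new_grid: int) -> torch.Tensor:
        extra, grid = pos_embed[:, :num_extra_tokens], pos_embed[:, num_extra_tokens:]
        old_grid = int(grid.shape[1] ** 0.5)
        grid = grid.reshape(1, old_grid, old_grid, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
        grid = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, -1)
        return torch.cat([extra, grid], dim=1)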

Ran the test code successfully; sharing my environment.

Hello every researcher,
I have been working with this code for 3 days, and finally I can run the test code successfully. Below is the test code:
    import torch
    from mamba_ssm import Mamba

    batch, length, dim = 2, 64, 16
    print(torch.cuda.is_available())
    x = torch.randn(batch, length, dim).to("cuda")
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim,  # Model dimension d_model
        d_state=16,   # SSM state expansion factor
        d_conv=4,     # Local convolution width
        expand=2,     # Block expansion factor
    ).to("cuda")
    y = model(x)
    assert y.shape == x.shape

Therefore, I would like to share my setup to help those who are still struggling with the environment.
My system:
- Ubuntu 22.04
- NVIDIA driver 535 (if you don't know how to upgrade or downgrade your NVIDIA driver, you can leave a message here)
- CUDA 11.8 (run file downloaded from the official website)
- Python 3.10.13 (installed via conda)
- PyTorch 2.2.2 (installed via conda)
- causal-conv1d 1.1.3.post1 (installed via pip install causal-conv1d==1.1.3.post1)
- mamba_ssm (installed from the author-provided source: cd mamba-1p1p1/ then pip install -e .)

Hidden State understanding

Hi @xinggangw @Unrealluver @ifzhang @LegendBC,
Thanks for releasing this great work. I am trying to interpret the model so that its results are more explainable. Using the pretrained image classification weights, I want to understand how I can record, during training or inference, the attention-like weights generated within the bidirectional SSM at each time step. I have gone through models_mamba.py, where we have hidden states and residuals. Can I use them directly to understand the model's predictions, or do I need to extract other details via main.py or models_mamba.py? If so, what directions could we look into, such as saliency maps? Any suggestions would be highly useful. Thanks in advance.
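
For example, would something like the following hook-based recording be a reasonable starting point? (A sketch assuming each block in model.layers returns (hidden_states, residual), as in models_mamba.py.)

    # Sketch only: capture each block's hidden states with forward hooks.
    import torch

    def register_hidden_state_hooks(model):
        records, handles = [], []
        for i, layer in enumerate(model.layers):
            def hook(module, inputs, output, idx=i):
                # each Vim block is assumed to return (hidden_states, residual)
                hidden = output[0] if isinstance(output, tuple) else output
                records.append((idx, hidden.detach().cpu()))
            handles.append(layer.register_forward_hook(hook))
        return records, handles

    # hypothetical usage:
    #   records, handles = register_hidden_state_hooks(model)
    #   model(images)
    #   for h in handles: h.remove()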

pip install -e mamba

I got the error: namespace "cub" has no member "WarpMask" when installing mamba, but I have no idea how to solve it. Could you give me some advice? Thanks.

Much lower training efficiency

Thanks for your great work! However, I observe that the training efficiency (both training speed and memory use) is much lower than that of a plain ViT with a similar model size. Do you have any insights on this phenomenon?

Mamba module initialization

Dear authors,

Thanks for the great work and for releasing the codebase! I have one question regarding the Mamba module initialization.

From the code, I see that its linear layers are initialized here and here, and then the weights and biases are overwritten here and here, which zeros out the Mamba linear layers' biases. So can this line actually be deleted? Please share your insights if it is intended. Thanks!

bimamba version

While reading your code, I noticed that bimamba has v1 and v2 versions, corresponding to BimambaInnerFn and MambaInnerFnNoOutProj respectively. v1 runs a forward pass and a reversed (flipped) pass over the same sequence but uses the same proj module to compute B and C, whereas v2 defines two proj modules (x_proj and x_proj_b) to compute separate B and C. Why was it designed this way, and which of the two works better? Finally, is the step() function in Mamba never executed, whether during training or inference, since step() does not specify the two-direction computation? Thanks a lot. 😦

Incompatible to CU118 when installing causal-conv1d and mamba

Thanks a lot to the authors for the wonderful work. I have some quick questions about the repo.

When installing causal-conv1d and mamba, I encountered an error saying "The detected CUDA version (12.3) mismatches the version that was used to compile PyTorch (11.8). Please make sure to use the same CUDA versions." Do you know how to address it?

I also tried several different torch and CUDA versions (e.g., tc211+cu121); it seems that sometimes the installation succeeds but model training becomes quite slow. Did you find similar issues, or does this repo have a strict constraint on the torch and CUDA versions?

Question about the design of the gate z

Hello,
I would like to express my appreciation for the outstanding work on this project.

In the original Mamba, there is no such 'z' mechanism used as a gate. However, in Vim, an additional 'z' has been incorporated as a gate. Why was it designed this way? What would happen if it were removed?

The code in question can be found at:

y = y * self.act(z) # (B D)
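
To make my question concrete, a toy sketch of the two variants (dummy tensors; assuming self.act is SiLU, as in mamba_simple.py):

    # Sketch only: contrast the gated output used in Vim with a hypothetical
    # ungated ablation.
    import torch
    import torch.nn.functional as F

    B, D, L = 2, 64, 16             # dummy batch, inner dim, sequence length
    y = torch.randn(B, D, L)        # stand-in for the selective-scan output
    z = torch.randn(B, D, L)        # stand-in for the gate branch from in_proj

    gated = y * F.silu(z)           # what the quoted line computes
    ungated = y                     # hypothetical ablation without the gate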


Your clarification would be immensely appreciated.

Memory/speed improvements over DeiT for larger Vim

I find your paper on Vision Mamba very interesting. However, when using your code, I encountered a problem (which may well be normal behavior). When analyzing GPU memory consumption and FPS for Vim versions other than Tiny, I could not achieve similar speed and memory improvements. I compared it to DeiT, and the improvements were only visible in Vim-Ti. Am I doing something wrong, or are the improvements only in the Tiny version?

Activation operation of Vim Block

Thanks for the great work!
According to the paper, there should be a SiLU activation operation in the Vim block. However, when I checked the following code in mamba-1p1p1/mamba_ssm/modules/mamba_simple.py, I didn't find the activation operation.

            elif self.bimamba_type == "v2":
                A_b = -torch.exp(self.A_b_log.float())
                out = mamba_inner_fn_no_out_proj(
                    xz,
                    self.conv1d.weight,
                    self.conv1d.bias,
                    self.x_proj.weight,
                    self.dt_proj.weight,
                    A,
                    None,  # input-dependent B
                    None,  # input-dependent C
                    self.D.float(),
                    delta_bias=self.dt_proj.bias.float(),
                    delta_softplus=True,
                )
                out_b = mamba_inner_fn_no_out_proj(
                    xz.flip([-1]),
                    self.conv1d_b.weight,
                    self.conv1d_b.bias,
                    self.x_proj_b.weight,
                    self.dt_proj_b.weight,
                    A_b,
                    None,
                    None,
                    self.D_b.float(),
                    delta_bias=self.dt_proj_b.bias.float(),
                    delta_softplus=True,
                )
                # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
                if not self.if_devide_out:
                    out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
                else:
                    out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2, self.out_proj.weight, self.out_proj.bias)

I further checked the mamba_inner_fn_no_out_proj method (MambaInnerFnNoOutProj class) in mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py and the CausalConv1dFn class in causal-conv1d/causal_conv1d/causal_conv1d_interface.py. Although an 'activation' argument appears in CausalConv1dFn, it seems to receive a 'None' value.

class MambaInnerFnNoOutProj(torch.autograd.Function):
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        ...
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias,None, True)
        ...

class CausalConv1dFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias=None, activation=None):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        ctx.save_for_backward(x, weight, bias)
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
        return out

License

Hi, thanks for releasing! Any plans to add a license? Thx!!

Cannot use pretrained model

Hello,

I am running this command on an Ubuntu 22.04 LTS system in order to try your model, as you suggest in the documentation: python main.py --finetune '' --eval --resume vim_t_midclstok_76p1acc.pth --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --data-path imagenet-mini --data-set 'IMNET'.

I installed mamba-ssm using pip wheels because your recommended approach does not work, as noted in another issue.

However, it seems that the pretrained model is not loaded correctly, as shown in the error message below. What can I do to fix this?

Traceback (most recent call last):
File "/rds/general/user/kp4718/home/code/Vim/vim/main.py", line 545, in
main(args)
File "/rds/general/user/kp4718/home/code/Vim/vim/main.py", line 448, in main
model_without_ddp.load_state_dict(checkpoint['model'])
File "/rds/general/user/kp4718/home/anaconda3/envs/mambaenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VisionMamba:
Unexpected key(s) in state_dict: "layers.0.mixer.A_b_log", "layers.0.mixer.D_b", "layers.0.mixer.conv1d_b.weight", "layers.0.mixer.conv1d_b.bias", "layers.0.mixer.x_proj_b.weight", "layers.0.mixer.dt_proj_b.weight", "layers.0.mixer.dt_proj_b.bias", "layers.1.mixer.A_b_log", "layers.1.mixer.D_b", "layers.1.mixer.conv1d_b.weight", "layers.1.mixer.conv1d_b.bias", "layers.1.mixer.x_proj_b.weight", "layers.1.mixer.dt_proj_b.weight", "layers.1.mixer.dt_proj_b.bias", "layers.2.mixer.A_b_log", "layers.2.mixer.D_b", "layers.2.mixer.conv1d_b.weight", "layers.2.mixer.conv1d_b.bias", "layers.2.mixer.x_proj_b.weight", "layers.2.mixer.dt_proj_b.weight", "layers.2.mixer.dt_proj_b.bias", "layers.3.mixer.A_b_log", "layers.3.mixer.D_b", "layers.3.mixer.conv1d_b.weight", "layers.3.mixer.conv1d_b.bias", "layers.3.mixer.x_proj_b.weight", "layers.3.mixer.dt_proj_b.weight", "layers.3.mixer.dt_proj_b.bias", "layers.4.mixer.A_b_log", "layers.4.mixer.D_b", "layers.4.mixer.conv1d_b.weight", "layers.4.mixer.conv1d_b.bias", "layers.4.mixer.x_proj_b.weight", "layers.4.mixer.dt_proj_b.weight", "layers.4.mixer.dt_proj_b.bias", "layers.5.mixer.A_b_log", "layers.5.mixer.D_b", "layers.5.mixer.conv1d_b.weight", "layers.5.mixer.conv1d_b.bias", "layers.5.mixer.x_proj_b.weight", "layers.5.mixer.dt_proj_b.weight", "layers.5.mixer.dt_proj_b.bias", "layers.6.mixer.A_b_log", "layers.6.mixer.D_b", "layers.6.mixer.conv1d_b.weight", "layers.6.mixer.conv1d_b.bias", "layers.6.mixer.x_proj_b.weight", "layers.6.mixer.dt_proj_b.weight", "layers.6.mixer.dt_proj_b.bias", "layers.7.mixer.A_b_log", "layers.7.mixer.D_b", "layers.7.mixer.conv1d_b.weight", "layers.7.mixer.conv1d_b.bias", "layers.7.mixer.x_proj_b.weight", "layers.7.mixer.dt_proj_b.weight", "layers.7.mixer.dt_proj_b.bias", "layers.8.mixer.A_b_log", "layers.8.mixer.D_b", "layers.8.mixer.conv1d_b.weight", "layers.8.mixer.conv1d_b.bias", "layers.8.mixer.x_proj_b.weight", "layers.8.mixer.dt_proj_b.weight", "layers.8.mixer.dt_proj_b.bias", "layers.9.mixer.A_b_log", "layers.9.mixer.D_b", "layers.9.mixer.conv1d_b.weight", "layers.9.mixer.conv1d_b.bias", "layers.9.mixer.x_proj_b.weight", "layers.9.mixer.dt_proj_b.weight", "layers.9.mixer.dt_proj_b.bias", "layers.10.mixer.A_b_log", "layers.10.mixer.D_b", "layers.10.mixer.conv1d_b.weight", "layers.10.mixer.conv1d_b.bias", "layers.10.mixer.x_proj_b.weight", "layers.10.mixer.dt_proj_b.weight", "layers.10.mixer.dt_proj_b.bias", "layers.11.mixer.A_b_log", "layers.11.mixer.D_b", "layers.11.mixer.conv1d_b.weight", "layers.11.mixer.conv1d_b.bias", "layers.11.mixer.x_proj_b.weight", "layers.11.mixer.dt_proj_b.weight", "layers.11.mixer.dt_proj_b.bias", "layers.12.mixer.A_b_log", "layers.12.mixer.D_b", "layers.12.mixer.conv1d_b.weight", "layers.12.mixer.conv1d_b.bias", "layers.12.mixer.x_proj_b.weight", "layers.12.mixer.dt_proj_b.weight", "layers.12.mixer.dt_proj_b.bias", "layers.13.mixer.A_b_log", "layers.13.mixer.D_b", "layers.13.mixer.conv1d_b.weight", "layers.13.mixer.conv1d_b.bias", "layers.13.mixer.x_proj_b.weight", "layers.13.mixer.dt_proj_b.weight", "layers.13.mixer.dt_proj_b.bias", "layers.14.mixer.A_b_log", "layers.14.mixer.D_b", "layers.14.mixer.conv1d_b.weight", "layers.14.mixer.conv1d_b.bias", "layers.14.mixer.x_proj_b.weight", "layers.14.mixer.dt_proj_b.weight", "layers.14.mixer.dt_proj_b.bias", "layers.15.mixer.A_b_log", "layers.15.mixer.D_b", "layers.15.mixer.conv1d_b.weight", "layers.15.mixer.conv1d_b.bias", "layers.15.mixer.x_proj_b.weight", "layers.15.mixer.dt_proj_b.weight", "layers.15.mixer.dt_proj_b.bias", 
"layers.16.mixer.A_b_log", "layers.16.mixer.D_b", "layers.16.mixer.conv1d_b.weight", "layers.16.mixer.conv1d_b.bias", "layers.16.mixer.x_proj_b.weight", "layers.16.mixer.dt_proj_b.weight", "layers.16.mixer.dt_proj_b.bias", "layers.17.mixer.A_b_log", "layers.17.mixer.D_b", "layers.17.mixer.conv1d_b.weight", "layers.17.mixer.conv1d_b.bias", "layers.17.mixer.x_proj_b.weight", "layers.17.mixer.dt_proj_b.weight", "layers.17.mixer.dt_proj_b.bias", "layers.18.mixer.A_b_log", "layers.18.mixer.D_b", "layers.18.mixer.conv1d_b.weight", "layers.18.mixer.conv1d_b.bias", "layers.18.mixer.x_proj_b.weight", "layers.18.mixer.dt_proj_b.weight", "layers.18.mixer.dt_proj_b.bias", "layers.19.mixer.A_b_log", "layers.19.mixer.D_b", "layers.19.mixer.conv1d_b.weight", "layers.19.mixer.conv1d_b.bias", "layers.19.mixer.x_proj_b.weight", "layers.19.mixer.dt_proj_b.weight", "layers.19.mixer.dt_proj_b.bias", "layers.20.mixer.A_b_log", "layers.20.mixer.D_b", "layers.20.mixer.conv1d_b.weight", "layers.20.mixer.conv1d_b.bias", "layers.20.mixer.x_proj_b.weight", "layers.20.mixer.dt_proj_b.weight", "layers.20.mixer.dt_proj_b.bias", "layers.21.mixer.A_b_log", "layers.21.mixer.D_b", "layers.21.mixer.conv1d_b.weight", "layers.21.mixer.conv1d_b.bias", "layers.21.mixer.x_proj_b.weight", "layers.21.mixer.dt_proj_b.weight", "layers.21.mixer.dt_proj_b.bias", "layers.22.mixer.A_b_log", "layers.22.mixer.D_b", "layers.22.mixer.conv1d_b.weight", "layers.22.mixer.conv1d_b.bias", "layers.22.mixer.x_proj_b.weight", "layers.22.mixer.dt_proj_b.weight", "layers.22.mixer.dt_proj_b.bias", "layers.23.mixer.A_b_log", "layers.23.mixer.D_b", "layers.23.mixer.conv1d_b.weight", "layers.23.mixer.conv1d_b.bias", "layers.23.mixer.x_proj_b.weight", "layers.23.mixer.dt_proj_b.weight", "layers.23.mixer.dt_proj_b.bias".
