mambamil's Introduction

MambaMIL: Enhancing Long Sequence Modeling with Sequence Reordering in Computational Pathology

License: MIT

NEWS

2024-05-14: Our paper was accepted early at MICCAI 2024.

Abstract

Multiple Instance Learning (MIL) has emerged as a dominant paradigm to extract discriminative feature representations within Whole Slide Images (WSIs) in computational pathology. Despite driving notable progress, existing MIL approaches suffer from limitations in facilitating comprehensive and efficient interactions among instances, as well as challenges related to time-consuming computations and overfitting. In this paper, we incorporate the Selective Scan Space State Sequential Model (Mamba) in Multiple Instance Learning (MIL) for long sequence modeling with linear complexity, termed as MambaMIL. By inheriting the capability of vanilla Mamba, MambaMIL demonstrates the ability to comprehensively understand and perceive long sequences of instances. Furthermore, we propose the Sequence Reordering Mamba (SR-Mamba) aware of the order and distribution of instances, which exploits the inherent valuable information embedded within the long sequences. With the SR-Mamba as the core component, MambaMIL can effectively capture more discriminative features and mitigate the challenges associated with overfitting and high computational overhead. Extensive experiments on two public challenging tasks across nine diverse datasets demonstrate that our proposed framework performs favorably against state-of-the-art MIL methods.

NOTES

2024-04-12: We will update the arXiv version of the paper next month.

2024-04-13: We released the model of MambaMIL. The whole training code is coming soon.

2024-04-24: We released the full version of MambaMIL, including models and training scripts.

Installation

  • Environment: CUDA 11.8 / Python 3.10
  • Create a virtual environment
> conda create -n mambamil python=3.10 -y
> conda activate mambamil
  • Install PyTorch 2.0.1
> pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
> pip install packaging
  • Install causal-conv1d
> pip install causal-conv1d==1.1.1
  • Install Mamba
> git clone git@github.com:isyangshu/MambaMIL.git
> cd mamba
> pip install .
  • Other requirements
> pip install scikit-survival==0.22.2
> pip install pandas==2.2.1
> pip install tensorboardx
> pip install h5py
> pip install wandb
> pip install tensorboard
> pip install lifelines

Repository Details

  • splits: Data splits for reproduction.
  • train_scripts: Training scripts for cancer subtyping and survival prediction.

How to Train

Prepare your data

  1. Download diagnostic WSIs from TCGA and BRACS.
  2. Use the WSI processing tool provided by CLAM to extract ResNet-50 and PLIP pretrained features for each 512 × 512 patch (20x), and save them as one .pt file per WSI. This yields one pt_files folder holding the .pt files for all WSIs of a study.

The final dataset structure should be as follows:

DATA_ROOT_DIR/
    └──pt_files/
        └──resnet50/
            ├── slide_1.pt
            ├── slide_2.pt
            └── ...
        └──plip/
            ├── slide_1.pt
            ├── slide_2.pt
            └── ...
        └──others/
            ├── slide_1.pt
            ├── slide_2.pt
            └── ...
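
A quick way to check a dataset against this layout is to count the .pt files under each feature folder. The sketch below is an illustration only: the helper name count_feature_files is hypothetical and is not part of this repository.

```python
from pathlib import Path


def count_feature_files(data_root):
    """Return {feature_folder_name: number of .pt files} under data_root/pt_files."""
    pt_root = Path(data_root) / "pt_files"
    counts = {}
    # Each sub-folder (e.g. resnet50, plip) holds one .pt file per slide.
    for feature_dir in sorted(p for p in pt_root.iterdir() if p.is_dir()):
        counts[feature_dir.name] = len(list(feature_dir.glob("*.pt")))
    return counts
```

Calling count_feature_files("DATA_ROOT_DIR") should report the same number of slides for every feature folder that is meant to cover the same study.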

Survival Prediction

We provide a training script for survival prediction: ALL_512_surivial_k_fold.sh.

Below are the supported models and datasets:

model_names='max_mil mean_mil att_mil trans_mil s4_mil mamba_mil'
backbones="resnet50 plip"
cancers='BLCA BRCA COADREAD KIRC KIRP LUAD STAD UCEC'

Run the following command to start training:

sh ./train_scripts/ALL_512_surivial_k_fold.sh

Cancer Subtyping

We provide training scripts for TCGA NSCLC cancer subtyping (LUAD_LUSC_512_subtyping.sh) and BReAst Carcinoma Subtyping (BRACS.sh).

Below are the supported models:

model_names='max_mil mean_mil att_mil trans_mil s4_mil mamba_mil'
backbones="resnet50 plip"

Run the following command to train TCGA NSCLC cancer subtyping:

sh ./train_scripts/LUAD_LUSC_512_subtyping.sh

Run the following command to train BReAst Carcinoma Subtyping:

sh ./train_scripts/BRACS.sh

Acknowledgements

Huge thanks to the authors of the following open-source projects:

License & Citation

If you find our work useful in your research, please consider citing our paper:

@article{yang2024mambamil,
  title={MambaMIL: Enhancing Long Sequence Modeling with Sequence Reordering in Computational Pathology},
  author={Yang, Shu and Wang, Yihui and Chen, Hao},
  journal={arXiv preprint arXiv:2403.06800},
  year={2024}
}

This code is available for non-commercial academic purposes. If you have any questions, feel free to email Shu YANG and Yihui WANG.

mambamil's People

Contributors

wyhsleep, isyangshu

mambamil's Issues

Patch Size and Evaluation code

This is an amazing repo. I have a few questions.

  1. Why do we need a patch size of 512? Can we use a different size?
  2. Do you intend to release evaluation code and heatmap scripts for the repository?
  3. Can we change the feature extractor?

Thanks,
Shubham

About nn.SiLU(), is the model structure drawn incorrectly?

I noticed that the authors have drawn nn.SiLU() in the model structure. However, I found that nn.SiLU() plays no role in the model. Commenting out self.act = nn.SiLU() on line 109 of srmamba.py does not raise an error, and lines 200-232 of srmamba.py, including mamba_inner_fn_no_out_proj, do not use nn.SiLU(). As I understand it, nn.SiLU() in Vision Mamba is applied through the Block wrapper. Can I assume that this is an error in the MambaMIL model structure figure?

Single WSI in training?

Hi,

Thanks for the amazing work! When a patient has multiple slides, did you drop the other slides and use only a single slide in training?

How did you handle this?

Did you drop the multiple slides from each case?

patients_df = slide_data.drop_duplicates(['case_id']).copy()

patients_df = slide_data.drop_duplicates(['case_id'])

Thanks,
Shubham
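
For context on the two quoted lines: pandas DataFrame.drop_duplicates(['case_id']) defaults to keep='first', so the resulting frame has one row per case. Whether the remaining slides are still loaded later depends on how the dataset class looks up slides for each case, which these two lines alone do not show. The dependency-free sketch below mirrors that default behaviour and contrasts it with grouping every slide under its case. The slide_id field and both function names are assumptions made for illustration.

```python
def first_row_per_case(rows):
    """Keep the first row seen for each case_id (mirrors keep='first')."""
    seen = set()
    kept = []
    for row in rows:
        if row["case_id"] not in seen:
            seen.add(row["case_id"])
            kept.append(row)
    return kept


def slides_by_case(rows):
    """Group every slide_id under its case_id, preserving input order."""
    grouped = {}
    for row in rows:
        grouped.setdefault(row["case_id"], []).append(row["slide_id"])
    return grouped


rows = [
    {"case_id": "P1", "slide_id": "P1_a"},
    {"case_id": "P1", "slide_id": "P1_b"},
    {"case_id": "P2", "slide_id": "P2_a"},
]
patients = first_row_per_case(rows)   # one row per patient
all_slides = slides_by_case(rows)     # every slide, keyed by patient
```

A per-patient table built with the first function can coexist with a per-patient slide list built with the second, so deduplicating by case does not by itself mean the extra slides are discarded.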

What do the outputs represent

Dear isyangshu,
I hope this message finds you well. I recently came across your repository MambaMIL on GitHub and found the work particularly interesting. However, I am having trouble understanding the meaning of the outputs in MambaMIL. Could you tell me what the outputs (hazards, S) represent?
Thank you very much for your time and assistance. I look forward to your response.
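
In discrete-time survival models of the kind commonly used in MIL survival pipelines, follow-up time is split into a few intervals. There, hazards holds the conditional probability of the event in each interval given survival up to it, and S is the survival curve obtained as the cumulative product of 1 - hazards. A scalar risk is often taken as the negative sum of S. The sketch below assumes that formulation; it has not been checked against this repository's code, and the function names are illustrative.

```python
import math


def hazards_from_logits(logits):
    """Sigmoid of each per-interval logit: P(event in bin k | survived to bin k)."""
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


def survival_from_hazards(hazards):
    """S[k] = product over j <= k of (1 - hazards[j]): P(surviving past bin k)."""
    survival = []
    running = 1.0
    for h in hazards:
        running *= 1.0 - h
        survival.append(running)
    return survival


def risk_score(survival):
    """Negative sum of S: a less negative value means shorter expected survival."""
    return -sum(survival)
```

With four intervals and all logits equal to zero, every hazard is 0.5, S is [0.5, 0.25, 0.125, 0.0625], and the risk is -0.9375.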

About Training GPU Memory

Dear Author:
Question 1: Your TransMIL code uses more GPU memory than the original implementation (https://github.com/szc19990412/TransMIL). Why is this?

Question 2: On datasets where the number of instances per bag varies a lot, I would expect GPU memory usage to jump during training. In practice it stays very stable. Can you tell me why?

Looking forward to your reply.

Training and Inference code

Hi,
Extremely wonderful work! I am curious to know if the training and inference code for this repository will be released anytime soon?

I am excited to implement this in our work. Thanks in advance.

Order of patch sequence in Mamba

Thank you for your work.

What order of patch sequences was used in Table 3?

That is, how were the patch feature sequences organized when they were fed into the Mamba blocks?

Thank you.

May I obtain one .pt file of one image from BRACS

Congratulations on your early acceptance by MICCAI!

I understand that the images were processed at 512x512 (20x), with a batch size of 1. If I am correct, each image was partitioned into 20 patches, each sized 512x512. Could I please obtain a .pt file for one image? I would like to feed this .pt file into the model to examine how each part of the model works.

Thank you!

Results differ across repeated runs of MambaMIL on Camelyon16

@wyhsleep @isyangshu
Hello, did you test your model on Camelyon16? I ran it multiple times, and the results were different.
When I run setup.py to install mamba_ssm, accuracy and AUC are around 92% and 95%. When I use it locally instead (in the form you provided, with the folder placed directly in the project), accuracy and AUC are only around 87% and 93%. In both cases, the results also differ between runs, even with the same random_seed.
I am using the MambaMIL.py file and mamba folder from your first upload. The training code is taken from TransMIL.
Looking forward to your reply.

CLAM Preprocessing

I want to ask about the parameters used to obtain the feature vectors from CLAM.

I'm referring specifically to the BLCA dataset.
If I use the default parameters, some WSIs have fewer than 437 patches, which is the number of patches sampled for survival analysis in Table 4.
This is an example output log:
batch 99, loss: 0.8604, label: 3, event_time: 29.5300, risk: -2.5116, bag_size: 274
batch 199, loss: 0.2883, label: 2, event_time: 18.7900, risk: -3.0559, bag_size: 313
Epoch: 0, train_loss_surv: 1.4153, train_loss: 1.4153, train_c_index: 0.4685
Epoch: 0, val_loss_surv: 1.3301, val_loss: 1.3301, val_c_index: 0.4093
Validation loss decreased (inf --> 0.409305). Saving model ...

batch 99, loss: 0.1626, label: 2, event_time: 21.0200, risk: -3.2254, bag_size: 335
batch 199, loss: 1.8531, label: 0, event_time: 6.7700, risk: -2.7776, bag_size: 268
Epoch: 1, train_loss_surv: 1.2493, train_loss: 1.2493, train_c_index: 0.5632
Epoch: 1, val_loss_surv: 1.2451, val_loss: 1.2451, val_c_index: 0.4983
Validation loss decreased (0.409305 --> 0.498314). Saving model ...

batch 99, loss: 0.2588, label: 1, event_time: 11.9600, risk: -2.6213, bag_size: 292
batch 199, loss: 1.1709, label: 1, event_time: 8.9400, risk: -2.2951, bag_size: 206
Epoch: 2, train_loss_surv: 1.1817, train_loss: 1.1817, train_c_index: 0.5791
Epoch: 2, val_loss_surv: 1.2481, val_loss: 1.2481, val_c_index: 0.5563
Validation loss decreased (0.498314 --> 0.556305). Saving model ...
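
To see how many bags fall below a given size, a log like the one above can be parsed directly. This is a minimal sketch that assumes the line format shown in this log; the regular expressions and the function name parse_log are illustrative.

```python
import re

# Matches the per-batch lines, e.g. "... risk: -2.5116, bag_size: 274"
BAG_SIZE_RE = re.compile(r"bag_size: (\d+)")
# Matches the per-epoch validation lines and captures epoch and c-index.
VAL_CINDEX_RE = re.compile(r"Epoch: (\d+), val_loss_surv: .*?val_c_index: ([\d.]+)")


def parse_log(text):
    """Return (bag_sizes, {epoch: val_c_index}) extracted from a training log."""
    bag_sizes = [int(size) for size in BAG_SIZE_RE.findall(text)]
    val_c_index = {int(epoch): float(c) for epoch, c in VAL_CINDEX_RE.findall(text)}
    return bag_sizes, val_c_index
```

Given the parsed list, sum(size < 437 for size in bag_sizes) counts the logged bags that are smaller than the sampling size mentioned above.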

Where is the 'Sequence Reordering Operation' in the codebase?

Hello, author! I wanted to express my appreciation for your work; I find it truly impressive. In the paper, you introduced the 'Sequence Reordering Operation.' Could you kindly direct me to the specific location within the codebase where this technique is implemented? I'm interested in knowing the exact line and file in the repository. Thank you!

Aggregate results of Mamba

Why did you use an attention operation to generate the slide features from Mamba's output?

I think the traditional Mamba uses average pooling to aggregate the sequence.
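
To make the two options concrete, the sketch below contrasts mean pooling with a simplified attention pooling over a bag of instance features. It is an illustration under stated assumptions, not the aggregation code of this repository: the scoring function is a single dot product with a vector w, whereas attention-based MIL models usually use a small learned network, and both function names are hypothetical.

```python
import math


def mean_pool(features):
    """Average instance features with equal weight per instance."""
    n = len(features)
    return [sum(column) / n for column in zip(*features)]


def attention_pool(features, w):
    """Weight each instance h_i by softmax(w . h_i), then sum the weighted features."""
    scores = [sum(wi * hi for wi, hi in zip(w, h)) for h in features]
    top = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    pooled = [sum(a * h[d] for a, h in zip(weights, features)) for d in range(len(w))]
    return pooled, weights
```

When w is all zeros every instance gets the same weight and attention pooling reduces to mean pooling; a non-zero w lets a few informative patches dominate the slide-level feature.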

How to Generate Balanced Splits

If I am using a set of WSIs that is slightly different from yours, how can I come up with splits for 5-fold validation that are balanced in terms of classes?
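
One common approach is stratified k-fold splitting, which assigns samples to folds class by class so that each fold keeps roughly the same class proportions (scikit-learn's StratifiedKFold does this). The dependency-free sketch below illustrates the idea. It is not the procedure used to produce the splits in this repository, and the function name stratified_folds is hypothetical.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_folds=5, seed=0):
    """Map each sample id to a fold index while balancing classes across folds.

    labels: dict of sample_id -> class label.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for sample_id, label in sorted(labels.items()):
        by_class[label].append(sample_id)

    assignment = {}
    offset = 0
    for label in sorted(by_class):
        ids = by_class[label]
        rng.shuffle(ids)
        # Deal samples round-robin; carrying the offset across classes keeps
        # small classes from all landing in fold 0.
        for i, sample_id in enumerate(ids):
            assignment[sample_id] = (offset + i) % n_folds
        offset += len(ids)
    return assignment
```

For each fold k, the samples assigned to k form the validation set and the rest form the training set. When a patient has several slides, passing case IDs rather than slide IDs as sample_id keeps all slides of one patient in the same fold.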

Downloading TCGA data

I have a question on downloading TCGA data. I used the gdc command line tool to download the WSIs for BRCA. However, a significant number of svs files did not download correctly, resulting in a lot of files with a 'svs.partial' file extension.

Is there a way to avoid this issue so that I can use the max number of WSIs?

Questions about the paper (number of blocks in the proposed method, overfitting)

Hello authors, this is very impressive work based on Mamba, and I appreciate your contribution.
I have some questions about the paper:
1. The paper compares Mamba and Vim. I am curious about the hyperparameters, e.g. the number of blocks in the proposed method. As far as I know, Vim uses 24.
2. The paper states that TransMIL has a serious overfitting problem. As far as I know, overfitting means that training performance is much higher than test performance, but I did not see a comparison between training and test results.

Please correct me if I am wrong.
Finally, thank you for your wonderful work.

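How should the parameters be set to obtain 512*512 patches (20x)?

Dear authors,

Hello! When processing the LUAD dataset with CLAM, how should the parameters of create_patches_fp.py and extract_features_fp.py be set to obtain 512*512 patches (20x)?

I look forward to your reply.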

python create_patches_fp.py --source DATA_DIRECTORY --save_dir RESULTS_DIRECTORY --patch_size 256 --seg --patch --stitch
CUDA_VISIBLE_DEVICES=0 python extract_features_fp.py --data_h5_dir DIR_TO_COORDS --data_slide_dir DATA_DIRECTORY --csv_path CSV_FILE_NAME --feat_dir FEATURES_DIRECTORY --batch_size 512 --slide_ext .svs

The perplexing experimental results

In the Cancer Subtyping experiment, the training results using features extracted by PLIP seem even worse than those using features extracted by Resnet50. Moreover, I observed that simply using Max-Pooling appears to outperform most of the comparative methods. I want to ask if you truly conducted your experiments with the necessary rigor and seriousness?

Considering an improvement with Mamba2?

Hello author, I see that your BiMamba code builds on Mamba1 (similar to Vision Mamba). Mamba2 was released recently. Have you considered using it to improve the model? Could you share your ideas for a BiMamba2?

Code release date

Hello, I am very interested in this work. When will the code be open-sourced?
