
MaxViT-UNet: Multi-Axis Attention for Medical Image Segmentation

Abstract

Convolutional neural networks (CNNs) have made significant strides in medical image analysis in recent years. However, the local nature of the convolution operator prevents CNNs from capturing global and long-range interactions. Recently, Transformers have gained popularity in the computer vision community, including medical image segmentation, but the poor scalability of the self-attention mechanism and the lack of a CNN-like inductive bias have limited their adoption. In this work, we present MaxViT-UNet, an encoder-decoder hybrid vision transformer for medical image segmentation. The proposed hybrid decoder, also based on the MaxViT block, is designed to harness the power of both convolution and self-attention at each decoding stage with minimal computational burden. The multi-axis self-attention in each decoder stage helps differentiate between object and background regions much more efficiently. Each hybrid decoder block first fuses the lower-level features, upsampled via transposed convolution, with the skip-connection features from the hybrid encoder, and then refines the fused features using the multi-axis attention mechanism. The proposed decoder block is repeated multiple times to accurately segment the nuclei regions. Experimental results on the MoNuSeg dataset demonstrate the effectiveness of the proposed technique: MaxViT-UNet outperformed the previous CNN-only (UNet) and Transformer-only (Swin-UNet) techniques by large margins of 2.36% and 5.31% on the Dice metric, respectively. (arXiv)
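The decoder stage described above can be illustrated with a small NumPy sketch. It is not the repository's implementation, and several details are assumptions: a single feature map in (H, W, C) layout, a 2x2 stride-2 transposed convolution for upsampling, fusion by channel concatenation followed by a 1x1 projection, and single-head attention without learned query/key/value projections, normalization, MLP, or relative position bias. The MBConv (convolutional) part of a MaxViT block is omitted. Block attention mixes tokens inside each local p x p window, while grid attention mixes tokens spread over a dilated g x g grid covering the whole map. All function names are hypothetical.

```python
import numpy as np


def transposed_conv2x2(x, weight):
    """Stride-2, 2x2 transposed convolution: (H, W, Cin) -> (2H, 2W, Cout).

    weight has shape (Cin, Cout, 2, 2); each input pixel writes one 2x2 output patch.
    """
    h, w, _ = x.shape
    out = np.einsum("ijc,cdab->iajbd", x, weight)  # (H, 2, W, 2, Cout)
    return out.reshape(2 * h, 2 * w, weight.shape[1])


def block_partition(x, p):
    """Split (H, W, C) into non-overlapping p x p windows: (num_windows, p*p, C)."""
    h, w, c = x.shape
    x = x.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, p * p, c)


def block_unpartition(windows, h, w, p):
    """Inverse of block_partition."""
    c = windows.shape[-1]
    x = windows.reshape(h // p, w // p, p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(h, w, c)


def grid_partition(x, g):
    """Group tokens spaced (H/g, W/g) apart into g x g grids: (num_groups, g*g, C)."""
    h, w, c = x.shape
    x = x.reshape(g, h // g, g, w // g, c).transpose(1, 3, 0, 2, 4)
    return x.reshape(-1, g * g, c)


def grid_unpartition(groups, h, w, g):
    """Inverse of grid_partition."""
    c = groups.shape[-1]
    x = groups.reshape(h // g, w // g, g, g, c).transpose(2, 0, 3, 1, 4)
    return x.reshape(h, w, c)


def self_attention(tokens):
    """Single-head scaled dot-product attention within each group (no learned projections)."""
    d = tokens.shape[-1]
    scores = tokens @ tokens.transpose(0, 2, 1) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    return attn @ tokens


def decoder_stage(low, skip, up_w, fuse_w, p=4, g=4):
    """Upsample `low`, fuse it with `skip`, then refine with block and grid attention."""
    up = transposed_conv2x2(low, up_w)                    # (H, W, Cup)
    fused = np.concatenate([up, skip], axis=-1) @ fuse_w  # 1x1 projection to C channels
    h, w, _ = fused.shape
    # Local mixing: attention inside each p x p window, with a residual connection.
    fused = fused + block_unpartition(self_attention(block_partition(fused, p)), h, w, p)
    # Global, sparse mixing: attention across a dilated g x g grid, with a residual connection.
    fused = fused + grid_unpartition(self_attention(grid_partition(fused, g)), h, w, g)
    return fused


rng = np.random.default_rng(0)
low = rng.standard_normal((8, 8, 32))     # deeper decoder feature map
skip = rng.standard_normal((16, 16, 16))  # encoder skip connection at 2x resolution
up_w = rng.standard_normal((32, 16, 2, 2)) * 0.1
fuse_w = rng.standard_normal((32, 16)) * 0.1
out = decoder_stage(low, skip, up_w, fuse_w)
print(out.shape)  # (16, 16, 16)
```

Stacking such a stage several times, each time at a higher resolution, roughly mirrors the repeated decoder blocks described in the abstract.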

Installation

Important versions:

python==3.10.4
cudatoolkit==11.3.1
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.12
openmim==0.2.1
mmcls==0.23.2
mmcv-full==1.6.0
mmsegmentation==0.24.1

The complete environment specification is in the environment.yml file.
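A quick way to compare an existing environment against these pins is sketched below using only the standard library (`importlib.metadata`). It is a convenience check, not part of the repository. It assumes the conda `pytorch` package is visible under the distribution name `torch`, ignores local build tags such as `+cu113`, and does not inspect cudatoolkit, which is not a Python distribution. The names `PINNED` and `find_mismatches` are hypothetical.

```python
import sys
from importlib import metadata

# Pinned versions copied from the list above, keyed by pip distribution name.
PINNED = {
    "torch": "1.12.1",
    "torchvision": "0.13.1",
    "timm": "0.6.12",
    "openmim": "0.2.1",
    "mmcls": "0.23.2",
    "mmcv-full": "1.6.0",
    "mmsegmentation": "0.24.1",
}


def find_mismatches(pinned, lookup=metadata.version):
    """Return {name: (expected, installed_or_None)} for every package that differs."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = lookup(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Local build tags such as "1.12.1+cu113" are treated as a match.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    print("python", sys.version.split()[0], "(expected 3.10.4)")
    for name, (expected, installed) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {installed}")
```

Run it inside the activated environment; apart from the Python version line, no output means every pinned package matches.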

Step 0. Download and install Miniconda from the official website.

Step 1. Create a conda environment and activate it.

conda create --name openmmlab python=3.10.4 -y
conda activate openmmlab

Step 2. Install PyTorch using the following command.

conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch

Step 3. Install MMCV using MIM.

pip3 install -U openmim
mim install mmcv-full==1.6.0

Step 4. Clone this repository and install it in editable mode.

git clone https://github.com/abdul2706/maxvit_unet.git
cd maxvit_unet
pip3 install -v -e .

For more details, refer to MMSegmentation's get_started.md.

Data

Refer to MoNuSeg18 for dataset-related information. To prepare the dataset in the MMSegmentation format, see dataset_prepare.md.
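MMSegmentation expects each annotation PNG to store class indices per pixel rather than display values. If your nuclei masks were exported as 0/255 images, a conversion like the one below may be needed. This is a minimal sketch that assumes a two-class setup (0 = background, 1 = nucleus); check dataset_prepare.md and the config for the convention actually used. The function name `to_label_map` is hypothetical.

```python
import numpy as np


def to_label_map(mask, threshold=127):
    """Convert a 0/255 binary mask to a 0/1 uint8 label map (class-index convention)."""
    mask = np.asarray(mask)
    if mask.ndim == 3:  # e.g. an RGB export of a binary mask: keep one channel
        mask = mask[..., 0]
    return (mask > threshold).astype(np.uint8)
```

The result can be written back as a single-channel PNG with Pillow, e.g. `Image.fromarray(label).save(path)`.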

Training

Before training MaxViT-UNet, set the values of MMSEG_HOME_PATH and DATASET_HOME_PATH in the config file. Then use the following commands:

# single-gpu training
python3 tools/train.py "configs/maxvit_unet/maxvit_unet_s1.py"

# multi-gpu training
CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 tools/dist_train.sh "configs/maxvit_unet/maxvit_unet_s1.py" 4
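Editing the two paths by hand is all that is required. If you prefer to script it, the helper below is one way to do so. It assumes both names are assigned at the top level of the config as simple `NAME = ...` lines, which you should confirm in the config file; the function name `set_config_paths` is hypothetical.

```python
import re


def set_config_paths(text, **values):
    """Rewrite top-level `NAME = ...` assignments in an MMSegmentation-style Python config."""
    for name, value in values.items():
        pattern = rf"^{re.escape(name)}\s*=.*$"
        # A callable replacement avoids re's backslash processing (e.g. Windows paths).
        text, count = re.subn(pattern, lambda _m: f"{name} = {value!r}", text, flags=re.MULTILINE)
        if count == 0:
            raise KeyError(f"{name} is not assigned at the top level of the config")
    return text
```

For example, with `cfg` a `pathlib.Path` pointing at `configs/maxvit_unet/maxvit_unet_s1.py`, `cfg.write_text(set_config_paths(cfg.read_text(), MMSEG_HOME_PATH="...", DATASET_HOME_PATH="..."))` updates the file in place.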

Inference

# single-gpu testing
python3 tools/test.py "configs/maxvit_unet/maxvit_unet_s1.py" "path/to/checkpoint.pth" --eval mDice

# multi-gpu testing
tools/dist_test.sh "configs/maxvit_unet/maxvit_unet_s1.py" "path/to/checkpoint.pth" <GPU_NUM> --eval mDice
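For a single foreground class, Dice and IoU are computed from pixel overlap as sketched below, and the two are related by Dice = 2 * IoU / (1 + IoU). This is a minimal NumPy illustration for binary masks, not MMSegmentation's evaluation code: for `mDice`, MMSegmentation accumulates intersections and unions over the whole test set and averages per-class scores (background included). The function name `dice_and_iou` is hypothetical.

```python
import numpy as np


def dice_and_iou(pred, target):
    """Foreground Dice and IoU for binary masks, pooled over all pixels given."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    # Two empty masks are treated as a perfect match.
    dice = 2 * inter / total if total else 1.0
    iou = inter / union if union else 1.0
    return float(dice), float(iou)


print(dice_and_iou([[1, 1, 0], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]]))
```

Here the masks share 2 of 4 covered pixels, giving IoU = 0.5 and Dice = 4/6.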

Results

Model Dice IoU
UNet 0.8185 0.6927
Swin-UNet 0.7956 0.6471
MaxViT-UPerNet 0.8176 0.6914
MaxViT-UNet 0.8378 0.7208
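The margins quoted in the abstract are relative Dice improvements over each baseline; the absolute differences are 1.93 and 4.22 points. The arithmetic from the table values is shown below. With these four-digit values the Swin-UNet margin comes out at 5.30%, so the abstract's 5.31% was presumably computed from unrounded scores.

```python
# Dice scores copied from the results table above.
dice = {"UNet": 0.8185, "Swin-UNet": 0.7956, "MaxViT-UNet": 0.8378}


def relative_gain(new, old):
    """Relative improvement of `new` over `old`, in percent."""
    return 100.0 * (new - old) / old


for baseline in ("UNet", "Swin-UNet"):
    gain = relative_gain(dice["MaxViT-UNet"], dice[baseline])
    print(f"{baseline}: +{gain:.2f}%")
```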

MaxViT-UNet weights trained on MoNuSeg18 can be downloaded from here.


Comparison Curves

Qualitative Comparison

Citation

@misc{rehman2023maxvitunet,
      title={MaxViT-UNet: Multi-Axis Attention for Medical Image Segmentation}, 
      author={Abdul Rehman and Asifullah Khan},
      year={2023},
      eprint={2305.08396},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
}
