
Adapting Pre-trained Vision Transformers from 2D to 3D through Weight Inflation Improves Medical Image Segmentation

MIT License · Python 3.8 · PyTorch · PyTorch Lightning · Black

This repo provides the PyTorch source code of our paper: Adapting Pre-trained Vision Transformers from 2D to 3D through Weight Inflation Improves Medical Image Segmentation. [Paper] [Supp] [Poster]

Abstract

Given the prevalence of 3D medical imaging technologies such as MRI and CT that are widely used in diagnosing and treating diverse diseases, 3D segmentation is one of the fundamental tasks of medical image analysis. Recently, Transformer-based models have started to achieve state-of-the-art performances across many vision tasks, through pre-training on large-scale natural image benchmark datasets. While works on medical image analysis have also begun to explore Transformer-based models, there is currently no optimal strategy to effectively leverage pre-trained Transformers, primarily due to the difference in dimensionality between 2D natural images and 3D medical images. Existing solutions either split 3D images into 2D slices and predict each slice independently, thereby losing crucial depth-wise information, or modify the Transformer architecture to support 3D inputs without leveraging pre-trained weights. In this work, we use a simple yet effective weight inflation strategy to adapt pre-trained Transformers from 2D to 3D, retaining the benefit of both transfer learning and depth information. We further investigate the effectiveness of transfer from different pre-training sources and objectives. Our approach achieves state-of-the-art performances across a broad range of 3D medical image datasets, and can become a standard strategy easily utilized by all work on Transformer-based models for 3D medical images, to maximize performance.

Approach

Figure: Approach overview. Large-scale pre-trained Transformers are used as the encoder of the segmentation model for transfer learning, and their weights are adapted with the inflation strategy to support 3D inputs. Each 3D image is split into windows, each containing a small number of neighboring slices. Each window is fed into the segmentation model, which predicts the segmentation of the center slice. The predicted slices are then aggregated to form the final 3D prediction.
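
The core operation, weight inflation, takes only a few lines of PyTorch. Below is a minimal sketch of the average-inflation idea; it is our own illustration rather than the repo's exact code, and the function name and layer sizes are assumptions:

import torch
import torch.nn as nn

def inflate_conv_weight(weight_2d: torch.Tensor, depth: int) -> torch.Tensor:
    """Inflate a 2D kernel (out_ch, in_ch, kH, kW) into a 3D kernel
    (out_ch, in_ch, depth, kH, kW) by replicating along depth and averaging."""
    weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1)
    return weight_3d / depth  # keeps activations on the same scale as the 2D model

# Example: inflate a ViT-B/16 patch embedding to consume 5 neighboring slices.
conv2d = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # pre-trained 2D patch embedding
conv3d = nn.Conv3d(3, 768, kernel_size=(5, 16, 16), stride=(5, 16, 16))
with torch.no_grad():
    conv3d.weight.copy_(inflate_conv_weight(conv2d.weight, depth=5))
    conv3d.bias.copy_(conv2d.bias)

With this initialization, the 3D layer reproduces the 2D layer's output whenever all slices in a window are identical, so training starts from the pre-trained solution rather than from scratch.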

Getting Started

Installation

Please install dependencies by running:

conda env create -f environment.yml

Dataset and Model

  • We use the BCV, ACDC, and MSD datasets; please register and download them.
  • Decompress each dataset and move it to the corresponding folder under src/data/ (e.g., move BCV to src/data/bcv/).
  • Run the pre-processing script split_data_to_slices_nii.py in each dataset folder to generate the processed data (e.g., BCV will be processed to src/data/bcv/processed/).
  • Run the weight-downloading script download_weights.sh in the src/backbones/encoders/pretrained_models/ folder.

Training

  1. To train the best segmentation model with both transfer learning and depth information (i.e., Ours in the paper), run:
cd src/
bash scripts/train_[dataset].sh

All hardware requirements for training, such as the number of GPUs, CPUs, and the amount of RAM, are listed in each script.

To use a different encoder, set --encoder swint/videoswint/dino.

To use your own dataset, modify --data_dir.

To adjust the number of training steps, modify --max_steps.

  2. To train the segmentation model with transfer learning but without depth information (i.e., Ours w/o D in the paper), simply add --force_2d 1 and run:
cd src/
# add `--force_2d 1` at the end of the script
bash scripts/train_[dataset].sh
  3. To train the segmentation model without transfer learning or depth information (i.e., Ours w/o T&D in the paper), add both --force_2d 1 and --use_pretrained 0 and run:
cd src/
# add both `--force_2d 1` and `--use_pretrained 0` at the end of the script
bash scripts/train_[dataset].sh

Results are displayed at the end of training and can also be found in wandb/ (open with the wandb web or local client).

Model files are saved in the MedicalSegmentation/ folder.

  4. To reproduce the UNETR baseline, follow the tutorial of the official UNETR release and set the correct training and validation sets:
https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb

We will clean up and integrate the UNETR code in the final code release.

Evaluation

To evaluate any trained segmentation model, simply add --evaluation 1 and --model_path <path/to/checkpoint> and run:

cd src/
# add `--evaluation 1` and `--model_path <path/to/checkpoint>` at the end of the script
bash scripts/train_[dataset].sh

Results are displayed at the end of evaluation and can also be found in wandb/ (open with the wandb web or local client).

The predictions are saved in dumps/ (open with pickle.Unpickler).
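
The dump files are standard pickle files, so loading them takes only a few lines; the filename below is illustrative (check dumps/ for the files your run actually produced):

import pickle

# Load a saved prediction dump produced by evaluation.
with open("dumps/predictions.pkl", "rb") as f:  # illustrative filename
    predictions = pickle.Unpickler(f).load()
print(type(predictions))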

Compute FLOPs

To compute the FLOPs of the vanilla Transformer and the inflated Transformer, simply run:

cd src/
python compute_flops.py

You should get 2D FLOPs: 213343401984 and 3D FLOPs: 214954014720, which shows that our method adds very little computational cost.
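
As a sanity check on that claim, the relative overhead implied by these two counts works out to under 1%:

# Relative FLOPs overhead of the inflated 3D model over the 2D model,
# computed from the counts reported above.
flops_2d = 213_343_401_984
flops_3d = 214_954_014_720
print(f"FLOPs increase: {(flops_3d - flops_2d) / flops_2d:.2%}")  # ~0.75%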

Citation

If you use this repo in your research, please cite it as follows:

@inproceedings{zhang2022adapting,
  title={Adapting Pre-trained Vision Transformers from 2D to 3D through Weight Inflation Improves Medical Image Segmentation},
  author={Zhang, Yuhui and Huang, Shih-Cheng and Zhou, Zhengping and Lungren, Matthew P and Yeung, Serena},
  booktitle={Machine Learning for Health},
  pages={391--404},
  year={2022},
  organization={PMLR}
}
