rpt's Introduction

RPT

Official code for paper: Few-Shot Medical Image Segmentation via a Region-enhanced Prototypical Transformer

Abstract

Automated segmentation of large volumes of medical images is often plagued by the limited availability of fully annotated data and the diversity of organ surface properties resulting from the use of different acquisition protocols for different patients. In this paper, we introduce a more promising few-shot learning-based method named Region-enhanced Prototypical Transformer (RPT) to mitigate the effects of large intra-class diversity/bias. First, a subdivision strategy is introduced to produce a collection of regional prototypes from the foreground of the support prototype. Second, a self-selection mechanism is proposed to incorporate into the Bias-alleviated Transformer (BaT) block to suppress or remove interferences present in the query prototype and regional support prototypes. By stacking BaT blocks, the proposed RPT can iteratively optimize the generated regional prototypes and finally produce rectified and more accurate global prototypes for Few-Shot Medical Image Segmentation (FSMS). Extensive experiments are conducted on three publicly available medical image datasets, and the obtained results show consistent improvements compared to state-of-the-art FSMS methods.

Getting started

Dependencies

Please install the following essential dependencies:

dcm2nii
json5==0.8.5
jupyter==1.0.0
nibabel==2.5.1
numpy==1.22.0
opencv-python==4.5.5.62
Pillow>=8.1.1
sacred==0.8.2
scikit-image==0.18.3
SimpleITK==1.2.3
torch==1.10.2
torchvision==0.11.2
tqdm==4.62.3
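
As a quick sanity check before training, the installed versions can be compared against the pins above. The following is a minimal sketch using only the standard library; the `PINNED` subset and the helper name `check_pins` are illustrative assumptions and are not part of this repository.

```python
from importlib import metadata

# A subset of the pins listed above; extend as needed (illustrative only).
PINNED = {
    "numpy": "1.22.0",
    "nibabel": "2.5.1",
    "sacred": "0.8.2",
    "SimpleITK": "1.2.3",
    "torch": "1.10.2",
}


def check_pins(pins):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_pins(PINNED).items():
        print(f"{name}: expected {expected}, found {installed}")
```

An empty result means every listed package matches its pin; dcm2nii is a separate command-line tool and is not covered by this check.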

Pre-processing is performed according to Ouyang et al., and we follow the procedure in their GitHub repository.

The trained models can be downloaded by:

  1. trained models for CHAOS under Setting 1
  2. trained models for CHAOS under Setting 2
  3. trained models for SABS under Setting 1
  4. trained models for SABS under Setting 2
  5. trained models for CMR

The pre-processed data and supervoxels can be downloaded by:

  1. Pre-processed CHAOS-T2 data and supervoxels
  2. Pre-processed SABS data and supervoxels
  3. Pre-processed CMR data and supervoxels

Training

  1. Compile ./supervoxels/felzenszwalb_3d_cy.pyx with cython (python ./supervoxels/setup.py build_ext --inplace) and run ./supervoxels/generate_supervoxels.py
  2. Download the pre-trained ResNet-101 weights (vanilla version or DeepLabv3 version), put them in your checkpoints folder, and then replace the absolute path in ./models/encoder.py.
  3. Run ./script/train.sh

Inference

Run ./script/test.sh

Acknowledgement

Our code is based on the following works: SSL-ALPNet, ADNet and QNet.

Citation

@inproceedings{zhu2023few,
  title={Few-Shot Medical Image Segmentation via a Region-Enhanced Prototypical Transformer},
  author={Zhu, Yazhou and Wang, Shidong and Xin, Tong and Zhang, Haofeng},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={271--280},
  year={2023},
  organization={Springer}
}

rpt's Issues

Experimental precision issues.

I reproduced your code and found that the experimental accuracy is far from what the paper reports. How did you calculate the final accuracy? Even the results I got in the 50% fold cross validation, using the best support image, are 3 points below the results in the paper.

The loss did not decrease during training, please check the code

Hello author, thank you very much for your excellent work. However, the loss clearly does not decrease when I run your code. Could you check the code? The paper also mentions that you used cross-entropy loss, boundary loss, and Dice loss, but your code lacks the Dice loss, and the boundary loss does not decrease during training. Looking forward to your code update.
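
For readers unfamiliar with the Dice loss mentioned in this issue, here is a minimal sketch of a generic soft Dice loss for a binary mask. It is written in NumPy purely for illustration (a training version would use differentiable PyTorch tensors), and it is not the authors' implementation; the function name `soft_dice_loss` and the smoothing constant `eps` are assumptions.

```python
import numpy as np


def soft_dice_loss(prob, target, eps=1e-6):
    """Generic soft Dice loss: 1 - 2|P*T| / (|P| + |T|).

    prob   : predicted foreground probabilities in [0, 1]
    target : binary ground-truth mask with the same shape
    eps    : smoothing term (an assumption) that avoids division by zero
    """
    prob = np.asarray(prob, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    intersection = np.sum(prob * target)
    denominator = np.sum(prob) + np.sum(target)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return 1.0 - dice


mask = np.array([[1, 0], [0, 1]])
perfect = soft_dice_loss(mask, mask)       # prediction equals the mask
disjoint = soft_dice_loss(1 - mask, mask)  # prediction misses every pixel
```

A perfect prediction gives a loss close to 0, and a prediction with no overlap gives a loss close to 1.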

BUG in class TrainDataset(Dataset)

Hello author. I found a bug in the datasets.py file with settings that don't match the description in the paper.

The __getitem__ method in class TrainDataset, which should get the index of the slices of this patient that contain the exclude_label in case exclude_label is passed in, has the following original code:

exclude_idx = np.full(gt.shape[0], True, dtype=bool)
for i in range(len(self.exclude_label)):
    exclude_idx = exclude_idx & (np.sum(gt == self.exclude_label[i], axis=(1, 2)) > 0)
    print(f'{i}_exclude_idx: {idx[exclude_idx]}')
exclude_idx = idx[exclude_idx]

It does not achieve the above effect. If my exclude_label is [1, 2, 3, 4], then the exclude_idx I get for patient 38 and the sli_idx I get later are shown below:
image

Here is my (probably) corrected code:

exclude_idx = np.full(gt.shape[0], False, dtype=bool)
for i in range(len(self.exclude_label)):
    exclude_idx = exclude_idx | (np.sum(gt == self.exclude_label[i], axis=(1, 2)) > 0)
exclude_idx = idx[exclude_idx]

The result is as follows:
image
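
The difference between the two versions can be reproduced on a toy volume. The sketch below assumes that `idx` in the snippets above is `np.arange(gt.shape[0])`; the two helper names and the toy data are hypothetical and only illustrate the AND versus OR behaviour.

```python
import numpy as np


def slices_with_all_labels(gt, labels):
    """Original logic: AND across labels, so a slice must contain every label."""
    mask = np.full(gt.shape[0], True, dtype=bool)
    for label in labels:
        mask &= np.sum(gt == label, axis=(1, 2)) > 0
    return np.arange(gt.shape[0])[mask]


def slices_with_any_label(gt, labels):
    """Proposed logic: OR across labels, so any one label marks the slice."""
    mask = np.full(gt.shape[0], False, dtype=bool)
    for label in labels:
        mask |= np.sum(gt == label, axis=(1, 2)) > 0
    return np.arange(gt.shape[0])[mask]


# Toy ground truth: 4 slices of 2x2 pixels (hypothetical data).
gt = np.zeros((4, 2, 2), dtype=np.int64)
gt[0, 0, 0] = 1  # slice 0 contains label 1 only
gt[1, 0, 0] = 1  # slice 1 contains labels 1 and 2
gt[1, 1, 1] = 2
gt[3, 0, 1] = 2  # slice 3 contains label 2 only; slice 2 is empty

all_idx = slices_with_all_labels(gt, [1, 2])  # only slice 1 qualifies
any_idx = slices_with_any_label(gt, [1, 2])   # slices 0, 1 and 3 qualify
```

With the AND version, a slice is selected only when all excluded labels appear in it at once, which is rare for four organs. The OR version is equivalent to `np.isin(gt, labels).any(axis=(1, 2))`.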

how to see the predicted results

Hello. I am having a problem visualizing the predicted 3D images and comparing them with ground truth as you have shown in your paper. Could you provide instructions or resources on how to accomplish this?

How to get the paper's Dice?

Dear author,
Thanks for reading this issue. When I use your code to train and test on the SABS dataset, I cannot reach the Dice scores shown in your paper. I hope you can help me check my config, which is shown below:
image
image

When I use the experimental setup described in your paper (30K iterations in total, decay of 0.8 every 1000 iterations, etc.) to train this model, the test scores are far from those in the paper: they only reach 6-14% Dice. Besides, I noticed that the loss hardly decreases under the paper's config. All the experiments were trained on an NVIDIA RTX 3090, but training took only 5 hours.
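
To make the schedule quoted above concrete, the sketch below computes the learning-rate multiplier for a staircase decay of 0.8 every 1000 iterations. The staircase form (the same idea as PyTorch's `StepLR` with `step_size=1000, gamma=0.8`) and the `base_lr` value are assumptions; the issue does not state the initial learning rate or how the decay is applied in the code.

```python
def lr_multiplier(iteration, gamma=0.8, step=1000):
    """Staircase decay: multiply the learning rate by gamma every `step` iterations."""
    return gamma ** (iteration // step)


base_lr = 1e-3  # hypothetical initial learning rate, not taken from the paper

for it in (0, 1000, 10000, 20000, 30000):
    print(f"iteration {it:>5}: lr = {base_lr * lr_multiplier(it):.3e}")
```

Under these assumptions the multiplier after 30K iterations is 0.8 to the power 30, which is roughly 0.0012, so the learning rate in the late iterations is about three orders of magnitude smaller than at the start.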

visualization of the other datasets

Hello. Thank you so much for providing the code for visualization. However, since the visualization code is for CMR, I am unable to run the code as I still haven't received the dataset from the MS-CMR Challenge. Moreover, I am unable to make changes to make the code work for CHAOST2 or SABS.

So I hope you will provide the visualization code for the other datasets.

How to visualize the results?

Dear author,
Thanks for your attention. I have obtained the predicted results for the SABS-CT dataset using your released code, but one thing still troubles me: the results cannot be overlaid directly on the original images. The problem is shown below:
image
image

When I use the MRIcroGL, the normalized CT images don't match the final results.

About the process of the data

First of all, thank you for your excellent work!
Regarding the data processing, I have a question: Section 2.1 mentions 3-dimensional pixel clustering, and Section 3 says that all 3D scans are reformatted into 2D axial and 2D short-axis slices.
So the final inputs to the model are 2D slices, correct? May I ask how you performed the 3D clustering?
I'm looking forward to your reply!

About Settings 1 and 2

Hello author, thank you very much for your excellent work.

I have a question about the code: which parameters in the code correspond to Setting 1 and Setting 2? I see this line in config.py: exclude_label = [1,2,3,4] # None, for not excluding test labels; Setting 1: None, Setting 2: True. If I want to run Setting 1, do I need to change it to None? Are there other parameters that I need to modify?

Or do I only need to modify the two parameters EXCLUDE_LABEL=None and TEST=1234 in train.sh to switch between Settings 1 and 2?

Training and testing problems on SABS dataset

Dear author, when I ran the SABS experiment with your code and the pre-processed SABS supervoxels you provided, the test results were for SPLEEN, RK, LK and GALLBLADDER, not SPLEEN, RK, LK and Liver as stated in the paper. Besides, GALLBLADDER has a very low Dice score. What is the reason? Is the dataset == 'SABS' branch of def get_label_names(dataset) in dataset_specifics.py incorrect (as shown in the figure below)? Could you please fix it?
image

Visualize the results in MRI and CT datasets

Could the author release the visualization code for the CT and MRI datasets? After running your code, I found that the positions of the organs in CT and MRI differ from those in Q-Net, AD-Net, etc.

I am looking forward to your reply.

CHAOST setting1 task

Thank you for your excellent work. I had a hard time reproducing the best results reported in your paper for the CHAOST setting1 task when training with the code. Could you share more detailed hyperparameter settings? Thank you very much!

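Question about data visualization

Hello, in the mask_generation.py file in the visualization folder, cmr_lvbp_pre_show, cmr_img_lvmyo_show, and cmr_img_rv_show are all undefined variables, so the results cannot be saved.
When I tried to save the predictions, the reference labels, and the original images as slices, I found that the labels and original images have shape 13×256×256, while the predictions have shape 11×256×256. How can I solve this problem? Thank you.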

About the new code

Dear author,

I noticed that your recently updated code has some parameters that differ from those in the paper, such as the change from self.fg_num = 10 # number of foreground partitions to self.fg_num = 100 # number of foreground partitions. The loss function has changed too.

Could you please tell me why these changes were made? Thanks!

Where is your paper?

Thank you very much for your wonderful work. May I ask where I can read your paper?

About the code section in the paper compared to other models

Dear author, thank you very much for your outstanding contribution. May I ask you about the comparison experiments with other models in your paper? For the two papers CRAPNet (WACV'23) and SR&CL (MICCAI'22), I did not find the source code published by the author. May I ask whether the author published the source code but I did not find it? Or did you reproduce it yourself from the original paper? (Is it convenient for you to publish the source code for these two models?)
