
Deep learning models for COVID-19 chest x-ray classification: Preventing shortcut learning using feature disentanglement

Disclaimer: This repository is provided for research and development use only. The models described in this repo are not intended for use in clinical decision-making or any other clinical use, and their performance for clinical use has not been established.

This README describes how to reproduce the models and a subset of the experiments from the preprint "Deep learning models for COVID-19 chest x-ray classification: Preventing shortcut learning using feature disentanglement". Note: the "CC-CCII" dataset is proprietary, so we cannot include it; however, this repository reproduces all results that depend only on the COVIDx dataset.

A rough sketch of how the repository is set up:

  • preprocess.py is run with a path to a "metadata.csv" file that contains pointers to unprocessed chest x-ray images (in DICOM, JPEG, or PNG format) and their disease labels. This script creates masked and unmasked copies of each input image, resized/cropped to 224x224, in a desired location, with an accompanying "metadata_preprocessed.csv" file.
  • create_embeddings.py is run with a path to a preprocessed dataset (a "metadata_preprocessed.csv" file). This script creates embeddings using an existing pre-trained model (currently, the three models we list in the paper) that can be used to quickly train a classifier. We call the output of this step an embedded dataset.
  • train.py uses an embedded dataset and corresponding domain/task labels to train a classifier. This can be done with or without "feature disentanglement" (see the sketch after this list).
  • evaluate.py uses a classifier and a preprocessed dataset to generate embeddings for new data. Note that these embeddings are not the same as those generated by one of the existing pre-trained models (e.g. torchxrayvision), but are the result of applying a pre-trained model together with a trained classifier (see the paper for more details about this distinction).
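To make the "feature disentanglement" option concrete, here is a minimal, illustrative sketch of one common approach: a domain-adversarial head with gradient reversal (DANN-style) over precomputed embeddings. This is an assumption for illustration only; the exact objective used by train.py is the one described in the paper.

# Illustrative sketch (NOT this repository's implementation): a DANN-style
# domain-adversarial classifier over precomputed embeddings.
import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    # Identity on the forward pass; negates (and scales) gradients on backward.
    @staticmethod
    def forward(ctx, x, lam):
        ctx.lam = lam
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lam * grad_output, None

class DisentangledClassifier(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, num_classes, num_domains, lam=1.0):
        super().__init__()
        self.lam = lam
        self.encoder = nn.Sequential(nn.Linear(embedding_dim, hidden_dim), nn.ReLU())
        self.task_head = nn.Linear(hidden_dim, num_classes)    # disease labels
        self.domain_head = nn.Linear(hidden_dim, num_domains)  # dataset/source labels

    def forward(self, x):
        z = self.encoder(x)
        # The domain head sees z through gradient reversal, so minimizing the
        # domain loss updates the encoder to make domains HARDER to predict,
        # discouraging domain-specific (shortcut) features.
        return self.task_head(z), self.domain_head(GradReverse.apply(z, self.lam))

During training one would minimize cross-entropy on both heads; the reversal layer flips the domain gradient as it enters the encoder, which is the disentanglement pressure.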

Setup

The following commands create a conda environment with the necessary requirements for reproducing our results. Note: our conda environment assumes CUDA version 11; you may need to adjust this to match your system.

conda env create -f environment.yml
conda activate xray-feature-disentanglement

Downloading the COVID-Net model

To get the COVID-Net pre-trained model weights:

  • Go here and click on the "COVIDNet-CXR Large" link. This should take you to a Google Drive folder hosted by the authors of COVID-Net.
  • Download the following files into data/pretrained_models/COVIDNet-CXR_Large (a quick sanity check is shown after this list):
    • checkpoint
    • model-8485.data-00000-of-00001
    • model-8485.index
    • model.meta
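As an optional sanity check that the download succeeded, something like the following (a hypothetical snippet, not part of the repository) verifies that all four files are in place:

import os

model_dir = "data/pretrained_models/COVIDNet-CXR_Large"
expected = [
    "checkpoint",
    "model-8485.data-00000-of-00001",
    "model-8485.index",
    "model.meta",
]
missing = [f for f in expected if not os.path.exists(os.path.join(model_dir, f))]
print("all files present" if not missing else "missing: %s" % missing)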

Downloading the COVIDx dataset

We use two CXR datasets in the accompanying paper: the open COVIDx dataset and a private dataset from the CC-CCII. We cannot include the CC-CCII dataset, so this repository serves to reproduce the experiments that depend only on the COVIDx dataset.

To download the COVIDx dataset and create a "metadata.csv" file to use throughout the pipeline:

  • Run notebooks/Preprocessing- Create COVID-Net dataset.ipynb to download and assemble the COVIDx data (this notebook is copied from the COVID-Net project).
  • Run notebooks/Preprocessing - COVIDx - combine splits and create metadata file.ipynb to combine the splits and create data/metadata_covidx.csv.

Editing utils.py

In utils.py you should set BASE_DIR to point to the full path of where you have cloned this repository.
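For example (the path below is a placeholder; point it at your own clone):

BASE_DIR = "/home/<username>/xray-feature-disentanglement"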

Reproducing results with the COVIDx dataset

Data preprocessing

We assume that datasets are defined as a list of filenames with corresponding metadata (e.g. "label" and "patient_id") in what we call "metadata.csv" files. The preprocess.py script consumes a "metadata.csv" file, standardizes the dimensions of each image, applies lung masking to each image, etc., and saves the results to a directory.
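For illustration only, a minimal "metadata.csv" might look like the following; the image-path column name and the exact label values are assumptions here, and preprocess.py defines the authoritative schema:

filename,label,patient_id
raw_images/patient0001.dcm,COVID-19,p0001
raw_images/patient0002.png,pneumonia,p0002
raw_images/patient0003.jpg,normal,p0003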

Run the following using the data/metadata_covidx.csv file created by notebooks/Preprocessing - COVIDx - combine splits and create metadata file.ipynb:

mkdir -p datasets/covidx/
MKL_THREADING_LAYER=GNU python preprocess.py --input_fn data/metadata_covidx.csv --output_dir datasets/covidx/ --disable_flip_preprocessing --overwrite

Generating embeddings from external pre-trained models

After the datasets have been preprocessed, we run create_embeddings.py to generate embeddings from different feature extractor models, which we then use to train a classifier.

# generate embeddings for masked images
python create_embeddings.py --input covidx --name covidx --model xrv --mask masked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model densenet --mask masked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model covidnet --mask masked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model histogram --mask masked --output_dir datasets/embeddings/

# generate embeddings for unmasked images
python create_embeddings.py --input covidx --name covidx --model xrv --mask unmasked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model densenet --mask unmasked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model covidnet --mask unmasked --output_dir datasets/embeddings/
python create_embeddings.py --input covidx --name covidx --model histogram --mask unmasked --output_dir datasets/embeddings/
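To make the role of an embedded dataset concrete, the sketch below trains a simple classifier on stand-in embeddings (random data in place of the arrays written by create_embeddings.py, whose on-disk format we do not assume here); train.py is the real entry point and additionally supports feature disentanglement:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)
X = rng.rand(200, 1024)          # stand-in for an (n_images, embedding_dim) array
y = rng.randint(0, 2, size=200)  # stand-in for binary disease labels

X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print("held-out accuracy:", clf.score(X_te, y_te))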

Generating results

We provide ways to reproduce three sets of results from the paper:

  • The first two columns of Table 2
  • The first two columns of Table 3
  • (roughly) The UMAP visualization

The first set of results is generated by python generate_table2_results.py.

The second set of results is generated by python run_main_experiments.py followed by python generate_table3_results.py. NOTE: running the experiments will take a long time and require GPU resources.

The third set of results is generated by bash run_umap_experiments.sh and the notebook notebooks/Results - Generate UMAPs.ipynb.
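For reference, a UMAP projection of embeddings boils down to something like the following sketch (not taken from the notebook; random stand-in data here, where in practice X would be an (n, d) embedding matrix):

import numpy as np
import umap  # provided by the umap-learn package

X = np.random.RandomState(0).rand(100, 1024)  # stand-in for real embeddings
coords = umap.UMAP(n_components=2, random_state=0).fit_transform(X)
print(coords.shape)  # (100, 2)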

Miscellaneous information

We use the torchxrayvision project from here at commit b274a7a32c462faff6df8cde711498d34f1acc36 on the master branch.

We use the COVID-Net project from here at commit d6f3552f44f1af99981dbc960ee46ea3bceecd61 on the master branch. Specifically, we use the pre-trained model "COVIDNet-CXR Large", and the dataset creation notebook (copied to notebooks/Preprocessing- Create COVID-Net dataset.ipynb).

We use the lungVAE project from here at commit 52b44df82a351706db2f575758ea3b8452389998 on the master branch. We also make the following small changes (see lungVAE/): the mask output directory is taken directly from --saveLoc, the non-DICOM input filetype is changed from png to jpg, and the predicted mask is moved to the CPU before post-processing.

diff --git a/predict.py b/predict.py
index 163a775..3876066 100644
--- a/predict.py
+++ b/predict.py
@@ -124,7 +124,7 @@ t = time.strftime("%Y%m%d_%H_%M")
 if args.saveLoc is '':
        save_dir = args.data+'pred_'+t+'/'
 else:
-       save_dir = args.saveLoc+'pred_'+t+'/'
+       save_dir = args.saveLoc + '/'
 if not os.path.exists(save_dir):
        os.mkdir(save_dir)

@@ -134,7 +134,7 @@ print("Model "+args.model.split('/')[-1]+" Number of parameters:%d"%(nParam))
 if args.dicom:
        filetype = 'DCM'
 else:
-       filetype= 'png'
+       filetype= 'jpg'

 files = list(set(glob(args.data+'*.'+filetype)) \
                        - set(glob(args.data+'*_mask*.'+filetype)) \
@@ -144,11 +144,14 @@ files = sorted(files)
 for fIdx in range(len(files)):
                f = files[fIdx]
                fName = f.split('/')[-1]
-               img, roi, h, w, hLoc, wLoc, imH, imW = loadDCM(f,
-                                                                                                       no_preprocess=args.no_preprocess,
-                                                                                                       dicom=args.dicom)
+               img, roi, h, w, hLoc, wLoc, imH, imW = loadDCM(
+                       f,
+                       no_preprocess=args.no_preprocess,
+                       dicom=args.dicom
+               )
                img = img.to(device)
-               _,mask = net(img)
+               _, mask = net(img)
+               mask = mask.cpu()
                mask = torch.sigmoid(mask*roi)
                f = save_dir+fName.replace('.'+filetype,'_mask.png')

License

This project is licensed under the MIT License.
