
ncomms2022's Introduction

Multimodal deep learning for Alzheimer's disease dementia assessment

This work is published in Nature Communications (https://doi.org/10.1038/s41467-022-31037-5).

Introduction

This repository contains the implementation of a deep learning framework that accomplishes two diagnostic steps to identify persons with normal cognition (NC), mild cognitive impairment (MCI), Alzheimer’s disease (AD) dementia, and dementia due to other etiologies (nADD).

We demonstrated that the framework compares favorably with the diagnostic performance of neurologists and neuroradiologists. To interpret the model, we conducted SHAP (SHapley Additive exPlanations) analysis on brain MRI and other features to reveal disease-specific patterns that correspond with expert-driven ratings and neuropathological findings.

Prerequisites

The tool was developed using the following dependencies:

  1. PyTorch (1.10 or greater).
  2. NumPy (1.19 or greater).
  3. tqdm (4.31 or greater).
  4. nibabel (3.2 or greater).
  5. matplotlib (3.3 or greater).
  6. scikit-learn (0.23 or greater).
  7. scipy (1.5.4 or greater).
  8. shap (0.37 or greater).
  9. xgboost (1.3.3 or greater).
  10. catboost (0.24 or greater).

Please note that the dependencies may require Python 3.6 or greater. It is recommended to install and maintain all packages using conda or pip. For installation of GPU accelerated PyTorch, additional effort may be required. Please check the official websites of PyTorch and CUDA for detailed instructions.

Installation

We recommend cloning only the latest revision to avoid fetching the full commit history from the development stage:

git clone --depth 1 https://github.com/vkola-lab/ncomms2022.git

Documentation

Train a model

1. Train a CNN model

model_wrappers.py contains the interfaces for initializing, training, testing, saving, and loading the model, as well as for creating SHAP interpretable heatmaps. See below for a basic usage example.

from model_wrappers import Multask_Wrapper
from utils import read_json

model = Multask_Wrapper(
    tasks=['ADD', 'COG'],                            # a list of tasks to predict
    device=1,                                        # GPU device to use
    main_config=read_json('config.json'),            # general configuration for the experiment  
    task_config=read_json('task_config.json'),       # task specific configurations
    seed=1000
)                                       
model.train()                                                            
thres = model.get_optimal_thres()                    # get optimal threshold using validation dataset
model.gen_score(['test'], thres)                     # apply optimal threshold on test dataset and cache predictions
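The get_optimal_thres call above selects an operating threshold on the validation set. As an illustrative sketch only (not the repository's actual implementation), a common choice is the threshold maximizing Youden's J statistic on the validation ROC curve; the labels and scores below are hypothetical:

```python
import numpy as np
from sklearn.metrics import roc_curve

def youden_threshold(y_true, y_score):
    """Return the threshold that maximizes Youden's J = sensitivity + specificity - 1."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    return thresholds[np.argmax(tpr - fpr)]

# hypothetical validation labels and model scores
y_true = np.array([0, 0, 0, 1, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.65, 0.9])
thr = youden_threshold(y_true, y_score)
print(thr)  # the scores separate the classes perfectly here, so 0.65 is selected
```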

2. Train a Fusion/NonImg model

The interface for training a fusion model or non-imaging model is similar to that of the CNN model. See below for a basic example usage.

from nonImg_model_wrappers import NonImg_Model_Wrapper, Fusion_Model_Wrapper
from utils import read_json

model = NonImg_Model_Wrapper(
    tasks=['ADD', 'COG'],                            # a list of tasks to predict
    main_config=read_json('config.json'),            # general configuration for the experiment  
    task_config=read_json('task_config.json'),       # task specific configurations
    seed=1000
)                                       
model.train()                                                            
thres = model.get_optimal_thres()                    # get optimal threshold using validation dataset
model.gen_score(['test'], thres)                     # apply optimal threshold on test dataset and cache predictions

model = Fusion_Model_Wrapper(
    tasks=['ADD', 'COG'],                            # a list of tasks to predict
    main_config=read_json('config.json'),            # general configuration for the experiment  
    task_config=read_json('task_config.json'),       # task specific configurations
    seed=1000
)                                       
model.train()                                                            
thres = model.get_optimal_thres()                    # get optimal threshold using validation dataset
model.gen_score(['test'], thres)                     # apply optimal threshold on test dataset and cache predictions

Note:

  1. The Multask_Wrapper class defines a generic multi-task deep learning model with shared convolutional blocks for feature extraction and standalone task-specific MLPs for classification or regression.
  2. Model weights are stored under the subfolder of checkpoint_dir that corresponds to each experiment.
  3. Model predictions, along with labels, are saved as a csv file under the subfolder of tb_log corresponding to each experiment, so that performance evaluation depends solely on the output csv without re-running inference.

Evaluate a model

Since the gen_score method has already saved the raw predictions as csv, the evaluation pipeline only needs to read that information from the corresponding experiment folders under tb_log. The mean, standard deviation, or 95% confidence intervals are estimated from multiple independent experiments, for instance from five-fold cross-validation.
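For example, given per-fold metric values (the AUCs below are hypothetical), the mean, sample standard deviation, and a t-based 95% confidence interval of the mean can be computed as:

```python
import numpy as np
from scipy import stats

fold_aucs = np.array([0.91, 0.89, 0.93, 0.90, 0.92])   # hypothetical AUCs from 5 folds

mean = fold_aucs.mean()
std = fold_aucs.std(ddof=1)                            # sample standard deviation
# half-width of a 95% CI of the mean, using the t-distribution with n - 1 dof
half_width = stats.t.ppf(0.975, df=len(fold_aucs) - 1) * std / np.sqrt(len(fold_aucs))
print(f"AUC = {mean:.3f} ± {std:.3f}, "
      f"95% CI [{mean - half_width:.3f}, {mean + half_width:.3f}]")
```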

1. ROC/PR curves

from performance_eval import generate_roc, generate_pr
generate_roc(
    csv_files,              # list[csvfiles] produced from "gen_score". Mean and std are estimated from cross-validation experiments  
    positive_label, 
    color, 
    out_file
)
generate_pr(
    csv_files,              # list[csvfiles] produced from "gen_score". Mean and std are estimated from cross-validation experiments 
    positive_label, 
    color, 
    out_file
)

2. Performance table

The performance table contains accuracy, sensitivity, specificity, F-1 score, and MCC for the different tasks.

from performance_eval import perform_table
perform_table(
    csv_files,              # list[csvfiles] produced from "gen_score". Mean and std are estimated from cross-validation experiments
    output_name             # any name for the output csv file that contains metric information
)
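As a rough sketch of how such metrics can be derived from cached predictions (the labels and predictions below are hypothetical, and the actual perform_table implementation may differ), using scikit-learn:

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, matthews_corrcoef

# hypothetical cached labels and thresholded predictions for a binary task
y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_pred = np.array([0, 1, 1, 1, 0, 0, 1, 0])

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
metrics = {
    "accuracy":    accuracy_score(y_true, y_pred),
    "sensitivity": tp / (tp + fn),        # recall / true positive rate
    "specificity": tn / (tn + fp),        # true negative rate
    "f1":          f1_score(y_true, y_pred),
    "mcc":         matthews_corrcoef(y_true, y_pred),
}
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```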

3. Confusion matrix

from performance_eval import crossValid_cm
crossValid_cm(
    csv_files,              # list[csvfiles] produced from "gen_score". Mean and std are estimated from cross-validation experiments
    stage                   # if stage='test', confusion matrix for the test dataset will be generated
)

4. Full evaluation package

The full evaluation package compiles the ROC/PR curves, performance table, and confusion matrix together.

from performance_eval import whole_eval_package
whole_eval_package(model_name, 'test')         # evaluate on NACC testing set
whole_eval_package(model_name, 'OASIS')        # evaluate on OASIS dataset

Interpret models with SHAP

1. Interpret the CNN model (MRI saliency map)

The shap_mid method first loads pretrained weights and then generates the SHAP interpretable saliency map for a specific intermediate layer over all instances.

model = Multask_Wrapper(   # instantiate an already trained model
    tasks=['ADD', 'COG'],                            
    device=1,                                        
    main_config=read_json('config.json'),       
    task_config=read_json('task_config.json'),       
    seed=1000
)                                                                                                
model.shap_mid(
    task_idx=0,            # if task_idx == 0, the shap analysis will be about the ADD task (tasks[task_idx])
    path='somewhere/',     # where you want to save the generated shap numpy array
    file='test.csv',       # shap will be generated on each case from this file
    layer='block2conv'     # which layer of the model that you want to interpret 
)                                       

For more details, please see the SHAP documentation.

2. Interpret the Fusion/NonImg model (feature importance)

The shap method initializes the corresponding SHAP explainer for various models, including XGBoost, CatBoost, Random Forest, Decision Tree, Support Vector Machine, Nearest Neighbors, and Multi-layer Perceptron. See below for an example.

model = NonImg_Model_Wrapper(
    tasks=['ADD', 'COG'],                            # a list of tasks to predict
    main_config=read_json('config.json'),            # general configuration for the experiment  
    task_config=read_json('task_config.json'),       # task specific configurations
    seed=1000
)                                       
model.train()                                                            
thres = model.get_optimal_thres()                    # get optimal threshold using validation dataset
model.gen_score(['test'], thres)                     # apply optimal threshold on test dataset and cache predictions
shap_values, _ = model.shap("test_shap")             # get shap values for all features over instances from test dataset

Data visualization

Please find the scripts used for plotting from the FigureTable/ folder.

Data Preparation

To comply with the data distribution policies of the different study centers, we provide guidance on accessing and processing the meta information instead of sharing the data within this repo. The meta data contain demographic information, medical history, neuropsychological tests, and functional questionnaires. Please refer to our paper for a complete list of the features included.

1. Meta information

We collected and organized the meta data from eight cohorts into the folder structure below:

lookupcsv
│
├── raw_tables         # inside raw_tables, you should save the directly-downloaded tables.  
│   ├── NACC_ALL       # within each folder, there is a readme file to guide the user to access and dowload data   
│   │   ├── readme.txt 
│   ├── ADNI
│   ├── OASIS
│   ├── AIBL
│   ├── FHS
│   └── ...
│
├── derived_tables     # inside derived_tables, you should save the intermediate tables produced by our scripts.
│   ├── NACC_ALL       # within each folder, there is a readme file to guide the user to run the processing scripts that we provided 
│   │   ├── readme.txt
│   ├── ADNI
│   ├── OASIS
│   ├── AIBL
│   ├── FHS
│   └── ...
│
├── dataset_table      # this is where the final meta table is saved
│   ├── NACC_ALL
│   ├── ADNI
│   ├── OASIS
│   ├── AIBL
│   ├── FHS
│   └── ...
│
├── CrossValid         # concatenate meta tables from dataset_table/ and then split NACC into train, valid, test                  
│   ├── cross0         # each cross folder contains a different split; see our paper for details on how the splits were made
│   │   ├── train.csv  
│   │   ├── valid.csv
│   │   ├── test.csv
│   │   ├── OASIS.csv
│   │   ├── exter_test.csv
│   ├── cross1              
│   ├── cross2             
│   ├── cross3            
│   └── cross4            
└── ...

To prepare the meta data: (1) download data from the official data portals following the guidance in each readme and save it in raw_tables; (2) use the scripts provided in derived_tables to produce intermediate outputs; (3) use the scripts provided in dataset_table to produce the final meta table from the information in both raw_tables and derived_tables; (4) concatenate and split the data for cross-validation.
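Step (4) can be sketched with pandas and scikit-learn; the table below is a synthetic stand-in for the concatenated NACC meta table, and the 80/10/10 ratio is illustrative only (see the paper for the actual split strategy):

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# synthetic stand-in for the concatenated NACC meta table from dataset_table/
nacc = pd.DataFrame({"id": range(100), "age": [60 + i % 30 for i in range(100)]})

# illustrative 80/10/10 train/valid/test split with a fixed seed for reproducibility
train, rest = train_test_split(nacc, test_size=0.2, random_state=0)
valid, test = train_test_split(rest, test_size=0.5, random_state=0)
print(len(train), len(valid), len(test))  # 80 10 10
```

The resulting frames would then be written out as train.csv, valid.csv, and test.csv under a cross*/ folder.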

2. MRI processing

The pipeline for MRI processing is available in MRI_process/pipeline.sh. Four sample de-identified, processed MRI scans are provided in demo/mri/.

DEMO

We also provide a demo script (demo_inference.py) to demonstrate how to run inference on other data instances using pretrained CNN weights.

python demo_inference.py

Running the command above produces a csv table under the demo/ folder containing the model's predictions for the four MRI scans in demo/mri/.

ncomms2022's People

Contributors

shangranq


ncomms2022's Issues

Question about demo

Hi, I ran demo_inference.py and got the results below. They are not consistent with the given labels. Is this right?

filename ADD_score ADD_pred COG_score COG_pred COG ADD
demo1.npy 0.99999815 1 1.0072571 1 1 0
demo2.npy 0.007306823 0 1.2147483 1 0 0
demo3.npy 0.9996529 1 0.53027755 1 2 1
demo4.npy 0.99970007 1 0.7140738 1 2 0

Thanks a lot!

shap analysis issues

Hi, I got ideas from your paper and code: I used my own model, which accepts two imaging modalities (MRI and PET), with the model_wrappers in your code. I got good ROC and PR curves, but when I ran the SHAP analysis, the shap_values array contained a lot of zeros. This is my first time using SHAP analysis, so I want to know what causes this problem.
Looking forward to your reply!

(Translated from Chinese:) Hello, inspired by your paper and code, I tried using my own model to jointly diagnose the disease stage from MRI and PET. The ROC and PR results look good, but the csv files generated by the SHAP analysis under FigureTable/NeuroPathRegions/shap_csvfiles contain many zero values, which caused problems when I later ran FigureTable/brainNetwork. So I would like to know why the SHAP results are unsatisfactory even though the model performs well.
Looking forward to your reply!

paper title

Hello, can you tell me the title of your paper? I'm interested in your project.

installation issue from requirements.txt

After creating a conda environment, I used pip install -r requirements.txt but encountered many errors: no matching distribution found for many libraries.

For instance,
ERROR: Could not find a version that satisfies the requirement libpng==1.6.37 (from -r requirements.txt (line 40)) (from versions: none)

Can you please update the requirements.txt file to fix these issues? Thank you very much.

Need more information on NACC CSV files

I am trying to reproduce the results and I have access to the NACC dataset, but I am not able to find the four CSV files that you used in the derived tables section.

The following are the four files that I am having trouble obtaining; please share whether you altered or modified the data that you received from NACC.

  1. kolachalama12042020apet.csv
  2. kolachalama12042020csf.csv
  3. kolachalama12042020.csv
  4. kolachalama12042020mri.csv

Any further insight into the dataset used would be helpful.

Preprocessing pipeline

I have been trying to use your preprocessing pipeline on the OASIS3 dataset (T1w).
But nothing seems to happen. Could you please guide me on how to use it properly?


data processing

Hello, I would like to ask about the data processing part. Should pipeline.py, biasFieldCorrection.py, and segmentation.py be run sequentially? Why is segmentation (segmentation.py) necessary? The images I obtained by running these three Python scripts sequentially seem to be different from the image provided in the demo (demo1.nii). I would appreciate an answer. Thank you very much!
