
AI-based differential diagnosis of dementia etiologies on multimodal data

This work is published in Nature Medicine (https://doi.org/10.1038/s41591-024-03118-z).

Introduction

This repository contains the implementation of a deep learning framework for the differential diagnosis of dementia etiologies using multimodal data. Using data from 9 distinct cohorts totaling 51,269 participants, we developed an algorithmic framework based on transformers and self-supervised learning to perform differential diagnosis of dementia. The model classifies individuals into one or more of 13 meticulously curated diagnostic categories, each aligned closely with real-world clinical requirements. These categories span the entire spectrum of cognitive conditions, from normal cognition (NC) and mild cognitive impairment (MCI) to dementia (DE), and further include 10 dementia types.

Figure 1: Data, model architecture, and modeling strategy. (a) Our model for differential dementia diagnosis was developed using diverse data modalities, including individual-level demographics, health history, neurological testing, physical/neurological exams, and multi-sequence MRI scans. These data sources, whenever available, were aggregated from nine independent cohorts: 4RTNI, ADNI, AIBL, FHS, LBDSU, NACC, NIFD, OASIS, and PPMI. For model training, we merged data from NACC, AIBL, PPMI, NIFD, LBDSU, OASIS, and 4RTNI. We employed a subset of the NACC dataset for internal testing. For external validation, we utilized the ADNI and FHS cohorts. (b) A transformer served as the scaffold for the model. Each feature was processed into a fixed-length vector using a modality-specific embedding strategy and fed into the transformer as input. A linear layer connected the transformer to the output prediction layer. (c) A distinct portion of the NACC dataset was randomly selected to enable a comparative analysis of the model's performance against practicing neurologists. Furthermore, we conducted a direct comparison between the model and a team of practicing neuroradiologists using a random sample of cases with confirmed dementia from the NACC testing cohort. For both evaluations, the model and the clinicians had access to the same set of multimodal data. Finally, we assessed the model's predictions by comparing them with pathology grades available from the NACC, ADNI, and FHS cohorts.
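
Panel (b) describes the core design: each feature is embedded into a fixed-length token, a transformer fuses the tokens, and a linear layer produces the predictions. The snippet below is a minimal sketch of that idea, not the released implementation; the class name, dimensions, and hyperparameters are all illustrative.

import torch
import torch.nn as nn

class DiagnosisTransformer(nn.Module):
    """Minimal sketch: one token per feature -> transformer -> multi-label head."""

    def __init__(self, num_numerical, categorical_sizes, img_emb_dim,
                 d_model=128, num_labels=13):
        super().__init__()
        # Modality-specific embeddings: every feature becomes a fixed-length vector.
        self.num_proj = nn.ModuleList([nn.Linear(1, d_model) for _ in range(num_numerical)])
        self.cat_emb = nn.ModuleList([nn.Embedding(n, d_model) for n in categorical_sizes])
        self.img_proj = nn.Linear(img_emb_dim, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        # Linear layer connecting the transformer to the output prediction layer.
        self.head = nn.Linear(d_model, num_labels)

    def forward(self, x_num, x_cat, x_img):
        tokens = [p(x_num[:, i:i + 1]) for i, p in enumerate(self.num_proj)]
        tokens += [e(x_cat[:, i]) for i, e in enumerate(self.cat_emb)]
        tokens.append(self.img_proj(x_img))
        seq = torch.stack(tokens, dim=1)        # (batch, n_features, d_model)
        pooled = self.encoder(seq).mean(dim=1)  # pool over feature tokens
        return self.head(pooled)                # one logit per diagnostic label

model = DiagnosisTransformer(num_numerical=5, categorical_sizes=[2, 4], img_emb_dim=768)
logits = model(torch.randn(8, 5), torch.randint(0, 2, (8, 2)), torch.randn(8, 768))
probs = torch.sigmoid(logits)  # sigmoid: an individual may receive more than one label

Sigmoid rather than softmax outputs match the multi-label formulation, in which an individual can belong to one or more of the 13 diagnostic categories.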

Prerequisites

To set up the adrd package, run the following in the root of the repository:

pip install git+https://github.com/vkola-lab/nmed2024.git

The tool was developed using the following dependencies:

  1. Python (3.11.7 or greater)
  2. PyTorch (2.1 or greater)
  3. TorchIO (0.15 or greater)
  4. MONAI (1.1 or greater)
  5. NumPy (1.24 or greater)
  6. tqdm (4.62 or greater)
  7. pandas (1.5.3 or greater)
  8. nibabel (5.0 or greater)
  9. matplotlib (3.7.2 or greater)
  10. shap (0.43 or greater)
  11. scikit-learn (1.2.2 or greater)
  12. scipy (1.10 or greater)
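
As a quick environment sanity check, a small helper like the following (not part of the repository) confirms the installed versions:

from importlib.metadata import version, PackageNotFoundError

requirements = {
    "torch": "2.1", "torchio": "0.15", "monai": "1.1", "numpy": "1.24",
    "tqdm": "4.62", "pandas": "1.5.3", "nibabel": "5.0", "matplotlib": "3.7.2",
    "shap": "0.43", "scikit-learn": "1.2.2", "scipy": "1.10",
}

for pkg, minimum in requirements.items():
    try:
        print(f"{pkg}: found {version(pkg)} (need >= {minimum})")
    except PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED (need >= {minimum})")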

Installation

You can clone this repository using the following command:

git clone https://github.com/vkola-lab/nmed2024.git

Training

The training process consists of two stages:

1. Imaging feature extraction

All code related to training the imaging model with self-supervised learning is under ./dev/ssl_mri/.

Note: we used skull-stripped MRIs to generate our image embeddings. A script for skull stripping with the publicly available SynthStrip tool [2] is provided under dev/skullstrip.sh.
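
If you prefer to drive skull stripping from Python instead of the shell script, a thin wrapper around the SynthStrip command line might look like this (a sketch that assumes FreeSurfer's mri_synthstrip is on your PATH; it is not a substitute for dev/skullstrip.sh):

import subprocess
from pathlib import Path

def skull_strip(t1_path: Path, out_dir: Path) -> Path:
    """Run SynthStrip on one T1-weighted volume and return the output path."""
    out_dir.mkdir(parents=True, exist_ok=True)
    stripped = out_dir / t1_path.name.replace(".nii.gz", "_stripped.nii.gz")
    # -i and -o are SynthStrip's input/output flags [2].
    subprocess.run(["mri_synthstrip", "-i", str(t1_path), "-o", str(stripped)], check=True)
    return stripped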

a) Training the imaging feature extractor

We started from the self-supervised pre-trained weights of the Swin UNETR encoder (CVPR paper [1]), which can be downloaded from this link. The checkpoint should be saved under ./dev/ssl_mri/pretrained_models/.
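
For orientation, loading the downloaded checkpoint into MONAI's Swin UNETR typically looks like the sketch below; the file name model_swinvit.pt and the constructor arguments are assumptions, so defer to the defaults in scripts/run_swinunetr.sh:

import torch
from monai.networks.nets import SwinUNETR

# Hypothetical settings; out_channels is irrelevant here because only the
# encoder (swinViT) is used downstream for embeddings.
model = SwinUNETR(img_size=(128, 128, 128), in_channels=1, out_channels=1,
                  feature_size=48)

# MONAI's SwinUNETR provides load_from() for the self-supervised
# pretrained weights released with [1].
weights = torch.load("dev/ssl_mri/pretrained_models/model_swinvit.pt")
model.load_from(weights=weights)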

To finetune the pre-trained Swin UNETR on your own data, run the following commands:

cd dev/ssl_mri/
bash scripts/run_swinunetr.sh

The code can run in a multi-GPU setting; set --nproc_per_node to the number of available GPUs.

b) Saving the MRI embeddings

Once a finetuned checkpoint of the imaging model is saved, navigate to the repository's root directory and run dev/train.sh with the following changes in flag values:

img_net="SwinUNETR"
img_mode=2 # loads the imgnet, generates embeddings out of the MRIs input to the network, and saves them.
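
Conceptually, this step runs each MRI through the finetuned encoder, pools the deepest feature map into a fixed-length vector, and writes it to disk. The sketch below illustrates the idea; the pooling choice and file layout are illustrative, not the exact behavior of dev/train.sh:

import numpy as np
import torch

@torch.no_grad()
def save_embedding(model, mri, out_path):
    """Pool the deepest Swin UNETR encoder features into one vector and save it."""
    model.eval()
    feats = model.swinViT(mri)[-1]              # deepest feature map: (B, C, d, h, w)
    emb = feats.mean(dim=(2, 3, 4)).squeeze(0)  # global average pool -> (C,)
    np.save(out_path, emb.cpu().numpy())

# e.g. save_embedding(model, torch.randn(1, 1, 128, 128, 128), "sub-001_emb.npy")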

2. Training the backbone transformer

Once the image embeddings are saved, we train the backbone transformer on the multimodal data. Create a configuration file similar to default_conf_new.toml, categorizing each feature as numerical, categorical, or imaging (a minimal configuration sketch follows the flag examples below). Add the saved image embedding paths to your data file as an additional column, and set the type of this feature to imaging in the configuration file. Then navigate to the repository's root directory and run dev/train.sh with the following changes in flag values:

img_net="SwinUNETREMB" 
img_mode=1 # loads MRI embeddings and not the imgnet.

To train the model without imaging, please use the following flag values:

img_net="NonImg" 
img_mode=-1
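
For reference, a configuration fragment in the spirit of default_conf_new.toml can be checked with Python 3.11's built-in tomllib. The feature names and keys below are invented for illustration; mirror the structure of the provided file:

import tomllib

cfg_text = """
[feature.his_AGE]
type = "numerical"

[feature.his_SEX]
type = "categorical"
num_categories = 2

[feature.img_MRI_T1]
type = "imaging"
"""

cfg = tomllib.loads(cfg_text)
for name, spec in cfg["feature"].items():
    print(name, "->", spec["type"])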

Evaluation

All evaluation reports, including AUC-ROC curves, AUC-PR curves, confusion matrices, and detailed classification reports, were generated using the script dev/visualization_utils.py.
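
The underlying metrics are standard scikit-learn computations; for a single diagnostic label, the core of an AUC-ROC/AUC-PR report reduces to something like this generic sketch (toy data, not dev/visualization_utils.py itself):

import numpy as np
from sklearn.metrics import auc, precision_recall_curve, roc_curve

y_true = np.array([0, 0, 1, 1, 1])               # toy ground-truth labels
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.7])   # toy predicted probabilities

fpr, tpr, _ = roc_curve(y_true, y_score)
precision, recall, _ = precision_recall_curve(y_true, y_score)
print("AUC-ROC:", auc(fpr, tpr))
print("AUC-PR: ", auc(recall, precision))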

Demo

To make our deep learning framework for differential dementia diagnosis more accessible and user-friendly, we have hosted it on Hugging Face Spaces. This interactive demo lets users try the model in real time, providing an intuitive interface for entering diagnostic information and receiving diagnostic predictions. Check out our demo at https://huggingface.co/spaces/vkola-lab/nmed2024 to see the model in action.

References

[1] Tang, Y., Yang, D., Li, W., Roth, H.R., Landman, B., Xu, D., Nath, V. and Hatamizadeh, A., 2022. Self-supervised pre-training of Swin Transformers for 3D medical image analysis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 20730-20740).

[2] Hoopes, A., Mora, J.S., Dalca, A.V., Fischl, B. and Hoffmann, M., 2022. SynthStrip: Skull-stripping for any brain image. NeuroImage, 260, p.119474.

Citation

@article{xue2024ai,
  title={AI-based differential diagnosis of dementia etiologies on multimodal data},
  author={Xue, Chonghua and Kowshik, Sahana S and Lteif, Diala and Puducheri, Shreyas and Jasodanand, Varuna H and Zhou, Olivia T and Walia, Anika S and Guney, Osman B and Zhang, J Diana and Pham, Serena T and others},
  journal={Nature Medicine},
  pages={1--13},
  year={2024},
  publisher={Nature Publishing Group US New York}
}


