
Semi-Supervised Diffusion Model for Brain Age Prediction

This repository contains the code required to train a Semi-Supervised Diffusion Model for Brain Age Prediction. It is an adaptation of the original Diffusion Autoencoder repository, found at:

[Diffusion Autoencoder Repo]

Which was introduced in the paper:

Diffusion Autoencoders: Toward a Meaningful and Decodable Representation
K. Preechakul, N. Chatthee, S. Wizadwongsa, S. Suwajanakorn. 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR).

The data used in the workshop paper is not included in this repo: although a lot of it comes from public datasets, individuals must apply to the relevant consortia for access. We have provided detail on the datasets in the Appendix of the paper, so that access can be requested if one wishes.


Training a model

  1. Clone or download the files from this repo, preserving its directory structure.

  2. Make a virtual environment, either natively in Python by running:

    pip install virtualenv 
    
    virtualenv ss_diffage_env
    

    Or in conda by running:

    conda create -n ss_diffage_env
    
  3. Activate that environment. Native Python:

    source ./ss_diffage_env/bin/activate 
    

    Conda:

    conda activate ss_diffage_env
    
  4. Install all of the necessary dependencies by running:

    pip install -r requirement.txt
    
  5. Next, amend the file dataset.py so that it loads your data accordingly. An example dataloader can be found in the file.
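As a rough guide, a dataloader for this task might look like the sketch below. This is a hypothetical example, not code from dataset.py: the class name `BrainSliceDataset`, the dict keys, and the use of an age sentinel of -1 for unlabelled subjects are all assumptions, and real NIfTI loading (e.g. via nibabel) is stubbed out with random arrays so the sketch stays self-contained.

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class BrainSliceDataset(Dataset):
    """Hypothetical dataset pairing a 2D brain slice with an age label.

    In a semi-supervised setting not every subject has an age; here
    unlabelled subjects carry age = -1.0 so the age loss can mask them
    out. Real file loading (e.g. nibabel.load on .nii/.nii.gz files)
    is replaced by random arrays to keep the example runnable."""

    def __init__(self, n_subjects=8, slice_shape=(64, 64), labelled_frac=0.5):
        rng = np.random.default_rng(0)
        self.slices = rng.standard_normal(
            (n_subjects, 1, *slice_shape)).astype(np.float32)
        ages = rng.uniform(20, 80, n_subjects).astype(np.float32)
        labelled = rng.random(n_subjects) < labelled_frac
        ages[~labelled] = -1.0          # sentinel: no age label
        self.ages = ages

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        return {"img": torch.from_numpy(self.slices[idx]),
                "age": torch.tensor(self.ages[idx])}

ds = BrainSliceDataset()
batch = ds[0]
```

Whatever shape your real dataset takes, the key point is that each item must expose both the image slice and an age (or a marker for its absence) in a form the training loop can batch.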

  6. Then config.py should be amended so that the hyperparameters meet your specifications. These arguments live on the TrainConfig dataclass, which starts on line 25. Arguments of particular note are:

  • load_in : specifies how long training runs before the age prediction kicks in.
  • batch_size : the batch size used during training.

The make_dataset method on the TrainConfig class should also be amended to load your dataset accordingly. Again, examples have been left there as a guide.
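The shape of such a config might look like the following minimal sketch. It is not the repo's actual TrainConfig: only the field names load_in and batch_size come from the description above, the default values and the data_name field are invented, and make_dataset returns a placeholder instead of the real dataset class.

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    """Sketch of a training config in the spirit of config.py."""
    batch_size: int = 16      # size of training batches
    load_in: int = 10_000     # assumed meaning: steps of pure diffusion
                              # training before age prediction kicks in
    data_name: str = "brain_slices"   # invented selector field

    def make_dataset(self):
        # In the real repo this would construct the dataset defined in
        # dataset.py; a list stands in to keep the sketch self-contained.
        if self.data_name == "brain_slices":
            return list(range(100))
        raise NotImplementedError(self.data_name)

conf = TrainConfig(batch_size=32)
train_set = conf.make_dataset()
```

The advantage of keeping dataset construction on the config is that a single object fully specifies a run, which makes experiments easy to reproduce.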

  7. Following this, templates.py needs to be modified according to your model and data specification. Changes to conf.net_ch_mult, for example, will make your model smaller or bigger.

  8. Then train.py needs to be amended so that it calls the configuration for your dataset/particular model. An example has been left there as well.
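A template in this style is typically just a function that fills in a config; the sketch below illustrates the idea with plain dicts. The function name and every field except net_ch_mult (which the text above names) are assumptions, not the repo's actual templates.py code.

```python
def brain_age_autoenc(small=False):
    """Hypothetical template in the style of templates.py.

    net_ch_mult sets the channel multiplier per UNet resolution stage,
    so fewer/smaller entries give a smaller model (as noted above);
    the other fields are illustrative placeholders."""
    conf = {
        "net_ch_mult": (1, 2, 4) if small else (1, 2, 4, 8),
        "img_size": 64,
        "batch_size": 16,
    }
    return conf

small_conf = brain_age_autoenc(small=True)
full_conf = brain_age_autoenc()
```

train.py would then simply select one of these template functions, build the config, and hand it to the trainer.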

After following the above steps, the model will be ready to train with your specifications and dataset.

  9. Run:
    python3 train.py
    

It is advised that you also inspect the experiment.py file, as this is the location of the pytorch_lightning class LitModel, which further defines the training specifications. Methods on this class that should particularly be inspected are:

  • training_step : (line 420) modifications should be made to ensure that the data is loaded in each step appropriately.
  • training_epoch_end : (line 243) modifications should be made to log metrics at the end of each epoch.
  • ModelCheckpoint : (line 1008) modifications should be made to configure checkpointing according to your needs.
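To make the load_in behaviour concrete, here is a minimal sketch of how a semi-supervised age loss can be gated inside a training step. This is a hypothetical helper, not code from experiment.py: the function name, the -1 sentinel for unlabelled subjects, and the gating-by-global-step scheme are all assumptions consistent with the description above.

```python
import torch
import torch.nn.functional as F

def age_loss(pred_age, true_age, global_step, load_in):
    """Sketch of semi-supervised age-loss gating.

    Returns zero until `load_in` steps have elapsed (i.e. age
    prediction has not "kicked in" yet), then an MSE over only the
    subjects that carry a real age label (sentinel -1 = unlabelled)."""
    if global_step < load_in:
        return torch.tensor(0.0)
    labelled = true_age >= 0
    if not labelled.any():
        return torch.tensor(0.0)
    return F.mse_loss(pred_age[labelled], true_age[labelled])

pred = torch.tensor([60.0, 70.0])
true = torch.tensor([-1.0, 72.0])   # first subject unlabelled
early = age_loss(pred, true, global_step=0, load_in=10_000)
late = age_loss(pred, true, global_step=20_000, load_in=10_000)
```

In a LightningModule's training_step, such a term would simply be added to the diffusion loss, with self.global_step supplying the step count.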

The trainer also logs images and the MSE loss, so use of TensorBoard is advised. Launch it by running the following command in a terminal with the aforementioned environment active:

    tensorboard --logdir=checkpoints

This should open TensorBoard on a localhost port.


Prediction

If you have downloaded the relevant datasets and wish to make predictions yourself, then you can run the following script. This will work provided that you follow the pre-processing steps outlined in the Appendix of the paper and have saved the files as NIfTIs.

  1. Run the prediction script:
    python3 predict.py --data_dir <directory with niftis> --slice <the slice number that you will predict on> --ext <file extension e.g: .nii or .nii.gz> --checkpoint <path to model checkpoint>
    

The above script will load the model weights and save the predictions to a CSV called predicted_ages.csv. This file will have two columns: ID and predicted age.
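Downstream analysis of that file needs nothing beyond the standard library. The sketch below parses an in-memory stand-in for predicted_ages.csv; the exact column headers ("ID", "predicted_age") and subject IDs are assumptions based on the description above, so check them against your actual output file.

```python
import csv
import io

# Hypothetical contents of predicted_ages.csv; in practice you would
# use open("predicted_ages.csv") instead of io.StringIO.
csv_text = "ID,predicted_age\nsub-01,63.2\nsub-02,71.8\n"

rows = list(csv.DictReader(io.StringIO(csv_text)))
ages = {row["ID"]: float(row["predicted_age"]) for row in rows}
mean_age = sum(ages.values()) / len(ages)
```

From here it is straightforward to join predictions against chronological ages and compute, for example, the mean absolute error.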

Contributors

  • a-ijishakin
