CortexODE: Learning Cortical Surface Reconstruction by Neural ODEs

This is the official PyTorch implementation of the paper:

CortexODE: Learning Cortical Surface Reconstruction by Neural ODEs

Qiang Ma, Liu Li, Emma C. Robinson, Bernhard Kainz, Daniel Rueckert, Amir Alansary

Overview

CortexODE leverages Neural Ordinary Differential Equations (ODEs) to learn diffeomorphic flows for cortical surface reconstruction. The trajectories of the points on the surface are modeled as ODEs, where the derivatives of their coordinates are parameterized by a learnable Lipschitz-continuous deformation network. This provides theoretical guarantees against surface self-intersections.
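As a rough illustration of the idea above, the sketch below moves a few points along a flow with a fixed-step Euler solver. It is a minimal sketch in plain Python, not the repository's code: `velocity` is a hypothetical stand-in for the learned deformation network, and `euler_flow` is an illustrative name.

```python
import math

def velocity(point, t):
    """Toy stand-in for the deformation network (hypothetical, not learned).

    tanh is 1-Lipschitz, so this field is Lipschitz-continuous in the coordinates.
    """
    x, y, z = point
    return (math.tanh(-x), math.tanh(-y), math.tanh(-z))

def euler_flow(points, step_size=0.1, t_end=1.0):
    """Integrate dx/dt = velocity(x, t) from t=0 to t_end with fixed-step Euler."""
    n_steps = round(t_end / step_size)
    for k in range(n_steps):
        t = k * step_size
        # Every point follows the same ODE, starting from its own initial position.
        points = [
            tuple(c + step_size * v for c, v in zip(p, velocity(p, t)))
            for p in points
        ]
    return points

initial = [(1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (0.0, 0.0, -1.5)]
deformed = euler_flow(initial, step_size=0.1)
```

With this toy field every point drifts toward the origin without crossing it, which mirrors how a well-behaved flow keeps trajectories from colliding.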

CortexODE is integrated into an automatic learning-based pipeline. The pipeline uses a 3D U-Net to predict a white matter (WM) segmentation from brain MRI scans, and then generates a signed distance function (SDF) that represents an initial surface. Fast topology correction is applied to guarantee genus-0 topology. After isosurface extraction, two deformation networks are trained to deform the initial surface to the white matter and pial surfaces, respectively.
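To make the SDF step above concrete, here is a minimal sketch that converts a small 2D binary mask into a signed distance map by brute force. The function name `signed_distance`, the 2D setting, and the sign convention (negative inside) are assumptions made for illustration; the actual pipeline works on 3D volumes and may use a different convention.

```python
import math

def signed_distance(mask):
    """Brute-force signed distance for a small 2D binary mask.

    Assumed convention: negative inside the object, positive outside.
    Distance is measured to the nearest pixel of the opposite label,
    so the mask must contain both labels.
    """
    h, w = len(mask), len(mask[0])
    coords = [(i, j) for i in range(h) for j in range(w)]
    sdf = [[0.0] * w for _ in range(h)]
    for i, j in coords:
        inside = mask[i][j] == 1
        d = min(
            math.hypot(i - a, j - b)
            for a, b in coords
            if (mask[a][b] == 1) != inside
        )
        sdf[i][j] = -d if inside else d
    return sdf

# Toy "segmentation": a 3x3 object in a 5x5 image.
mask = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
]
sdf = signed_distance(mask)
```

The initial surface corresponds to the zero level set of such a map, which is where the sign changes between neighbouring pixels.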

Installation

This implementation mainly relies on the PyTorch, PyTorch3D and torchdiffeq libraries. The dependencies can be installed by running the following commands:

conda create --name cortexode
source activate cortexode
pip install -r requirements.txt

In addition, the PyTorch3D library (v0.4.0) should be installed manually.

Dataset

CortexODE is validated on the ADNI, HCP Young Adult and dHCP (3rd Release) datasets. The data splits for training/validation/testing are given in ./data/split. For the ADNI and HCP datasets, the data should be in FreeSurfer format. We have provided example data in ./data/adni/test/subject_1.

For a new dataset, please revise ./data/preprocess.py to make sure the cortical surfaces match the MRI volume. For data aligned to MNI-152 space, we recommend using data_name='adni' with minor modifications.

Evaluation

We have provided pretrained models in ./ckpts/pretrained for all datasets. To predict the cortical surfaces for the example ADNI data using the pretrained models, please run:

python eval.py --test_type='pred' --data_dir='./data/adni/test/' --model_dir='./ckpts/pretrained/adni/' --result_dir='./ckpts/experiment_1/result/' --data_name='adni' --surf_hemi='lh' --tag='pretrained' --solver='euler' --step_size=0.1 --device='gpu'

--data_dir is the path to the dataset. --model_dir is the path to the saved models. --result_dir is the path where the predicted surfaces are saved (obj / stl / FreeSurfer format). Please refer to config.py for detailed configurations. To evaluate the ASSD, HD and self-intersections, please use:

python eval.py --test_type='eval' --data_dir='./data/adni/test/' --model_dir='./ckpts/pretrained/adni/' --data_name='adni' --surf_hemi='lh' --tag='pretrained' --solver='euler' --step_size=0.1 --device='gpu'

The torch-mesh-isect package must be installed to detect mesh self-intersections.
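For readers unfamiliar with the two distance metrics named above, the following minimal sketch computes the average symmetric surface distance (ASSD) and the Hausdorff distance (HD) between two small point sets. It is an illustration only: the function names are hypothetical, the averaging of the two directions is one common convention, and the plain maximum is used for HD, whereas implementations often report a percentile instead. Which variants eval.py uses is not stated here.

```python
import math

def nearest_distances(src, dst):
    """Distance from each point in src to its nearest neighbour in dst."""
    return [min(math.dist(p, q) for q in dst) for p in src]

def assd(a, b):
    """Average symmetric surface distance: mean of the two one-way averages."""
    d_ab = nearest_distances(a, b)
    d_ba = nearest_distances(b, a)
    return 0.5 * (sum(d_ab) / len(d_ab) + sum(d_ba) / len(d_ba))

def hausdorff(a, b):
    """Symmetric Hausdorff distance: worst-case nearest-neighbour distance."""
    return max(max(nearest_distances(a, b)), max(nearest_distances(b, a)))

# Toy point sets standing in for points sampled from two surfaces.
pred = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
truth = [(0.0, 0.0, 1.0), (1.0, 0.0, 1.0), (3.0, 0.0, 1.0)]

assd_value = assd(pred, truth)
hd_value = hausdorff(pred, truth)
```

The outlying point at x=3 dominates HD but is averaged out in ASSD, which is why the two metrics are usually reported together.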

Training

The training of CortexODE models consists of segmentation, initial surface generation, and cortical surface reconstruction. The following instructions show how to train CortexODE models on the left brain hemisphere of the ADNI dataset.

Segmentation

For WM segmentation, please run the following command to train a 3D U-Net for 200 epochs:

python train.py --train_type='seg' --data_dir='./data/adni/' --model_dir='./ckpts/experiment_1/model/' --data_name='adni' --n_epoch=200 --tag='exp1' --device='gpu'

where --model_dir is the path where the model checkpoints are saved. experiment_1 and --tag='exp1' identify the experiment.

Initial Surface Generation

After training the WM segmentation, select the model with the best validation Dice score and run the following commands to create initial surfaces for both the training and validation sets:

python eval.py --test_type='init' --data_dir='./data/adni/train/' --model_dir='./ckpts/experiment_1/model/' --init_dir='./ckpts/experiment_1/init/train/' --data_name='adni' --surf_hemi='lh' --tag='exp1' --device='gpu'
python eval.py --test_type='init' --data_dir='./data/adni/valid/' --model_dir='./ckpts/experiment_1/model/'  --init_dir='./ckpts/experiment_1/init/valid/' --data_name='adni' --surf_hemi='lh' --tag='exp1' --device='gpu'

where --init_dir is the path to save the initial surfaces.
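The model selection step above depends on the validation Dice score. As a minimal sketch, the snippet below computes Dice for flat binary label sequences and picks the checkpoint with the highest score. The function name `dice_score`, the epoch numbers, and the tiny label arrays are all made up for illustration and do not come from the repository.

```python
def dice_score(pred, target):
    """Dice = 2|P and T| / (|P| + |T|) for flat binary label sequences."""
    intersection = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    total = sum(pred) + sum(target)
    # Two empty masks agree perfectly by convention.
    return 1.0 if total == 0 else 2.0 * intersection / total

# Toy ground truth and toy predictions keyed by (hypothetical) checkpoint epoch.
target = [0, 1, 1, 1, 0, 0, 1, 0]
predictions = {
    10: [0, 1, 0, 0, 0, 0, 1, 0],
    20: [0, 1, 1, 1, 0, 1, 1, 0],
    30: [1, 1, 1, 0, 0, 0, 1, 1],
}

val_dice = {epoch: dice_score(pred, target) for epoch, pred in predictions.items()}
best_epoch = max(val_dice, key=val_dice.get)
```

In practice the score would be averaged over all validation subjects before comparing checkpoints.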

Cortical Surface Reconstruction

We train two deformation networks for WM and pial surface reconstruction. We use the adjoint sensitivity method proposed in Neural ODEs to train our models with constant GPU memory cost. To train CortexODE for WM surface reconstruction using the Euler solver with step size h=0.1, you can use:

python train.py --train_type='surf' --data_dir='./data/adni/' --model_dir='./ckpts/experiment_1/model/' --init_dir='./ckpts/experiment_1/init/' --data_name='adni'  --surf_hemi='lh' --surf_type='wm' --n_epochs=400 --n_samples=150000 --tag='exp1' --solver='euler' --step_size=0.1 --device='gpu'

where --init_dir is the path to the input initial surfaces, and --n_samples is the number of randomly sampled points used to compute the Chamfer loss. For pial surface reconstruction, please run:

python train.py --train_type='surf' --data_dir='./data/adni/' --model_dir='./ckpts/experiment_1/model/' --data_name='adni'  --surf_hemi='lh' --surf_type='gm' --n_epochs=400 --tag='exp1' --solver='euler' --step_size=0.1 --device='gpu'

Note that we use lr=1e-4 as the learning rate to reduce the training time. As a result, the validation error curve can oscillate strongly. Please select the models with the best performance on the validation set.
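The Chamfer loss that --n_samples refers to above can be sketched as follows. This is a simplified illustration in plain Python, under stated assumptions: it uses the squared-distance form summed over both directions, and it subsamples from a list of points rather than sampling on mesh faces. The names `chamfer_distance` and `sample_points` are hypothetical, and the exact variant used by train.py is not specified in this document.

```python
import random

def squared_distance(p, q):
    return sum((x - y) ** 2 for x, y in zip(p, q))

def chamfer_distance(a, b):
    """Symmetric Chamfer distance: mean squared nearest-neighbour distance,
    computed in both directions and summed."""
    def one_way(src, dst):
        return sum(min(squared_distance(p, q) for q in dst) for p in src) / len(src)
    return one_way(a, b) + one_way(b, a)

def sample_points(points, n_samples, seed=0):
    """Randomly subsample points so the loss cost stays bounded."""
    rng = random.Random(seed)
    return rng.sample(points, min(n_samples, len(points)))

# Toy predicted and target point sets: the target is the prediction shifted by 0.5 in z.
pred = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 1.0, 0.0)]
target = [(x, y, 0.5) for x, y, _ in pred]

loss_full = chamfer_distance(pred, target)  # 0.5 for this toy pair
loss_sampled = chamfer_distance(
    sample_points(pred, 2), sample_points(target, 2, seed=1)
)
```

Sampling fewer points makes each evaluation cheaper but noisier, which is the trade-off the --n_samples setting controls.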
