
RegNet

Introduction

In this work we propose a method to solve nonrigid image registration through a learning approach, instead of via iterative optimization of a predefined dissimilarity metric. We design a Convolutional Neural Network (CNN) architecture that, in contrast to all other work, directly estimates the displacement vector field (DVF) from a pair of input images. The proposed RegNet is trained using a large set of artificially generated DVFs, does not explicitly define a dissimilarity metric, and integrates image content at multiple scales to equip the network with contextual information. At testing time nonrigid registration is performed in a single shot, in contrast to current iterative methods.

Citation

[1] Hessam Sokooti, Bob de Vos, Floris Berendsen, Mohsen Ghafoorian, Sahar Yousefi, Boudewijn P.F. Lelieveldt, Ivana Išgum and Marius Staring, 2019. 3D Convolutional Neural Networks Image Registration Based on Efficient Supervised Learning from Artificial Deformations. arXiv preprint arXiv:1908.10235.

[2] Hessam Sokooti, Bob de Vos, Floris Berendsen, Boudewijn P.F. Lelieveldt, Ivana Išgum, and Marius Staring, 2017, September. Nonrigid image registration using multi-scale 3D convolutional neural networks. In International Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 232-239). Springer, Cham.

1. Dependencies

  • Joblib: running Python functions as pipeline jobs.
  • Matplotlib: a plotting library for Python and its numerical mathematics extension NumPy.
  • NumPy: general-purpose array-processing package.
  • SimpleITK: a simplified interface to the Insight Toolkit (ITK) for image registration and segmentation.
  • SciPy: a Python-based ecosystem of open-source software for mathematics, science, and engineering.
  • TensorFlow v1.x: the deep learning framework used to build and train the network (the code targets the 1.x API).
  • xmltodict: a Python module that makes working with XML feel like working with JSON.
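The list above can be pinned in a requirements file along these lines (the version bounds are assumptions, not tested constraints; the code targets TensorFlow 1.x):

```
joblib
matplotlib
numpy
scipy
SimpleITK
tensorflow>=1.8,<2.0
xmltodict
```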

2. Running RegNet

Run RegNet3D.py. Please note that the current RegNet only works with 3D images.

2.1 Data

All images are read and written with SimpleITK. The images should already be resampled to an isotropic voxel size of [1, 1, 1] mm.

The images in the training and validation set can be defined in a list of dictionaries:

# A simple example of how to load the data:
import functions.setting.setting_utils as su

setting = su.initialize_setting(current_experiment='MyCurrentExperiment', where_to_run='Root')
data_exp_dict = [
    {'data': 'DIR-Lab_4D',           # Data to load. The image addresses can be modified in setting_utils.py
     'deform_exp': '3D_max7_D14_K',  # Synthetic deformation experiment
     'TrainingCNList': [1, 2, 3],    # Case numbers (patients) to load
     'TrainingTypeImList': [i for i in range(8)],    # Image types to load for each case, e.g. [baseline, follow-up]
     'TrainingDSmoothList': [i for i in range(14)],  # Indices of the synthetic deformations to load (see section 2.2.4)
     'ValidationCNList': [1, 2],
     'ValidationTypeImList': [8, 9],
     'ValidationDSmoothList': [0, 5, 10],
     },
    {'data': 'SPREAD',
     'deform_exp': '3D_max7_D14_K',
     'TrainingCNList': [i for i in range(1, 11)],
     'TrainingTypeImList': [0, 1],
     'TrainingDSmoothList': [i for i in range(14)],
     'ValidationCNList': [11, 12],
     'ValidationTypeImList': [0, 1],
     'ValidationDSmoothList': [0, 5, 10],
     },

]

setting = su.load_setting_from_data_dict(setting, data_exp_dict)
original_image_address = su.address_generator(setting, 'Im', data='DIR-Lab_4D', cn=1, type_im=0, stage=1)
print(original_image_address)

im_info_list_training = su.get_im_info_list_from_train_mode(setting, 'Training', load_mode='Single', read_pair_mode='Synthetic', stage=1)
im_info_list_training = im_info_list_training[0:4]
print('\n The first four elements are: ')
print(*im_info_list_training, sep="\n")

for im_info in im_info_list_training:
    im_info_su = {'data': im_info['data'], 'deform_exp': im_info['deform_exp'], 'type_im': im_info['type_im'],
                  'cn': im_info['cn'], 'dsmooth': im_info['dsmooth'], 'stage': im_info['stage'], }
    print(su.address_generator(setting, 'Im', **im_info_su))
    print(su.address_generator(setting, 'DeformedIm', **im_info_su))

Output:
./Data/DIR-Lab/4DCT/mha/case1/case1_T00_RS1.mha

 The first four elements are: 
{'data': 'DIR-Lab_4D', 'type_im': 0, 'cn': 1, 'deform_exp': '3D_max7_D14_K', 'dsmooth': 0, 'deform_method': 'respiratory_motion', 'deform_number': 0, 'stage': 1}
{'data': 'DIR-Lab_4D', 'type_im': 0, 'cn': 1, 'deform_exp': '3D_max7_D14_K', 'dsmooth': 1, 'deform_method': 'respiratory_motion', 'deform_number': 1, 'stage': 1}
{'data': 'DIR-Lab_4D', 'type_im': 0, 'cn': 1, 'deform_exp': '3D_max7_D14_K', 'dsmooth': 2, 'deform_method': 'respiratory_motion', 'deform_number': 2, 'stage': 1}
{'data': 'DIR-Lab_4D', 'type_im': 0, 'cn': 1, 'deform_exp': '3D_max7_D14_K', 'dsmooth': 3, 'deform_method': 'respiratory_motion', 'deform_number': 3, 'stage': 1}
./Data/DIR-Lab/4DCT/mha/case1/case1_T00_RS1.mha
./Elastix/Artificial_Generation/3D_max7_D14_K/DIR-Lab_4D/T00/case1/Dsmooth0/respiratory_motion_D0/DeformedImage.mha
./Elastix/Artificial_Generation/3D_max7_D14_K/DIR-Lab_4D/T00/case1/Dsmooth0/DNext1/NextIm.mha
./Elastix/Artificial_Generation/3D_max7_D14_K/DIR-Lab_4D/T00/case1/Dsmooth1/respiratory_motion_D1/DeformedImage.mha
./Elastix/Artificial_Generation/3D_max7_D14_K/DIR-Lab_4D/T00/case1/Dsmooth0/DNext2/NextIm.mha
./Elastix/Artificial_Generation/3D_max7_D14_K/DIR-Lab_4D/T00/case1/Dsmooth2/respiratory_motion_D2/DeformedImage.mha
./Elastix/Artificial_Generation/3D_max7_D14_K/DIR-Lab_4D/T00/case1/Dsmooth0/DNext3/NextIm.mha
./Elastix/Artificial_Generation/3D_max7_D14_K/DIR-Lab_4D/T00/case1/Dsmooth3/respiratory_motion_D3/DeformedImage.mha

'data':

The details of 'data' should be written in setting.setting_utils.py. The general setting of each 'data' should be defined in load_data_setting(selected_data), such as the file extension, the total number of image types and the default pixel value. The global data folder (setting['DataFolder']) can be defined in root_address_generator(where_to_run='Auto').

The details of the image address can be defined in setting.setting_utils.address_generator() after the line if data == 'YourOwnData':. For example, you can take a look at line 370: if data == 'DIR-Lab_4D':. The original images are addressed with requested_address='OriginalIm'. To test the reading function, run the script above and check original_image_address.

'deform_exp', 'TrainingDSmoothList':

See section 2.2.4 ('deform_exp', 'TrainingDSmoothList') under 2.2 Setting of generating synthetic DVFs.

'TrainingCNList', 'TrainingTypeImList':

'TrainingCNList' indicates the Case Numbers (CN) that you want to use for training. Usually, each CN refers to a specific patient. 'TrainingTypeImList' indicates which of the available image types for each patient you want to load. For example, in the SPREAD data two types are available: baseline and follow-up. In the DIR-Lab_4D data, 10 images per patient are available, from the maximum inhale to the maximum exhale phase.
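For DIR-Lab_4D, the ten image types correspond to the ten breathing phases; a hypothetical mapping from type_im index to phase name, based on the T00/T40 path pattern visible in the addresses above, would be:

```python
# Hypothetical mapping from type_im index to DIR-Lab 4DCT phase name
# (an assumption inferred from the path pattern, e.g. case1_T00_RS1.mha):
phase_names = ['T{:02d}'.format(10 * i) for i in range(10)]
# type_im=0 -> 'T00', ..., type_im=9 -> 'T90'
```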

2.2 Setting of generating synthetic DVFs

Four categories of synthetic DVFs are available in the software: zero, single frequency, mixed frequency and respiratory motion.

2.2.1 Zero 'zero'

2.2.2 Single frequency 'single_frequency'

For generating a single-frequency DVF, we propose the following algorithm:

  1. Initialize a B-spline grid of points with a grid spacing of deform_exp_setting['SingleFrequency_BSplineGridSpacing'].
  2. Perturb the grid points in a smooth and random fashion.
  3. Interpolate to get the DVF.
  4. Normalize the DVF linearly if it is outside the range [-deform_exp_setting['MaxDeform'], +deform_exp_setting['MaxDeform']].

By varying the grid spacing, different spatial frequencies are generated.

Figure 1: Single frequency; B-spline grid spacings of 40, 30 and 20 mm from left to right.
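The four steps above can be sketched with NumPy/SciPy, using a cubic zoom as a stand-in for the B-spline interpolation (function name, defaults and the coarse-grid padding are assumptions, not the repository implementation):

```python
import numpy as np
from scipy.ndimage import zoom

def single_frequency_dvf(im_shape, grid_spacing=10, max_deform=7.0, seed=None):
    """Sketch of the single-frequency algorithm described above."""
    rng = np.random.default_rng(seed)
    # Steps 1-2: a coarse grid of random displacements per axis
    # (the "perturbed B-spline grid"), padded by a few control points.
    grid_shape = [s // grid_spacing + 4 for s in im_shape]
    dvf = np.empty(list(im_shape) + [3])
    for axis in range(3):
        coarse = rng.uniform(-2 * max_deform, 2 * max_deform, size=grid_shape)
        # Step 3: smooth interpolation to full resolution (cubic spline
        # zoom approximates B-spline interpolation); oversize, then crop.
        factors = [(s + 2) / g for s, g in zip(im_shape, grid_shape)]
        fine = zoom(coarse, factors, order=3)
        dvf[..., axis] = fine[:im_shape[0], :im_shape[1], :im_shape[2]]
    # Step 4: linearly normalize into [-max_deform, +max_deform] if needed.
    peak = np.abs(dvf).max()
    if peak > max_deform:
        dvf *= max_deform / peak
    return dvf
```

A larger grid_spacing yields a smoother, lower-frequency DVF, matching the left-to-right progression in Figure 1.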

2.2.3 Mixed frequency 'mixed_frequency'

The steps for the mixed-frequency category are as follows:

  1. Extract edges with the Canny edge detection method.
  2. Copy the binary edge image three times to get a vector of three 3D images.
  3. Randomly set some voxels to zero in each image.
  4. Dilate each binary image for deform_exp_setting['MixedFrequency_Np'] iterations, using a random structuring element per image.
  5. Fill the dilated binary images with a DVF generated by the single-frequency method.
  6. Smooth the DVF with a Gaussian kernel with a standard deviation of deform_exp_setting['MixedFrequency_SigmaRange']. The sigma is relatively small, which leads to a higher spatial frequency compared with the filled DVF. By varying the sigma value and deform_exp_setting['MixedFrequency_BSplineGridSpacing'] of the filled DVF, different spatial frequencies are mixed together.


Figure 2: Mixed Frequency.
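Steps 1-4 can be sketched as follows; a Gaussian gradient-magnitude threshold stands in for the Canny detector (an assumption, since the repository uses a proper Canny implementation), and the parameter names are illustrative:

```python
import numpy as np
from scipy import ndimage

def mixed_frequency_masks(im, n_dilate=3, keep_prob=0.7, seed=None):
    """Sketch of the edge-based masks used by the mixed-frequency method."""
    rng = np.random.default_rng(seed)
    # Step 1: edge map, here voxels with a strong intensity gradient.
    grad = ndimage.gaussian_gradient_magnitude(im.astype(float), sigma=1.0)
    edges = grad > 0.5 * grad.max()
    masks = []
    for _ in range(3):  # Step 2: three copies, one per DVF component.
        m = edges & (rng.random(im.shape) < keep_prob)      # Step 3: dropout.
        struct = ndimage.generate_binary_structure(3, int(rng.integers(1, 4)))
        m = ndimage.binary_dilation(m, structure=struct,
                                    iterations=n_dilate)    # Step 4: dilation.
        masks.append(m)
    # Steps 5-6 (not shown): fill these masks with a single-frequency DVF
    # and smooth with a small-sigma Gaussian kernel.
    return np.stack(masks, axis=-1)
```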

2.2.4 'deform_exp', 'TrainingDSmoothList'

'deform_exp' is defined in setting.artificial_generation_setting.py. For example, you can use multiple types of single frequency and mixed frequency:

def_setting['DeformMethods'] = ['respiratory_motion', 'respiratory_motion', 'respiratory_motion', 'respiratory_motion',
                                 'single_frequency', 'single_frequency', 'single_frequency', 'single_frequency', 'single_frequency',
                                 'mixed_frequency', 'mixed_frequency', 'mixed_frequency', 'mixed_frequency',
                                 'zero']

The above setting is used at generation time. However, you might not want to load all of the generated deformations at reading time; 'TrainingDSmoothList' and 'ValidationDSmoothList' select them by their index in def_setting['DeformMethods'].

'ValidationDSmoothList': [2, 4, 8]: with the list above, this loads the third respiratory_motion type (index 2), the first single_frequency type (index 4) and the fifth single_frequency type (index 8).
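The index-based selection can be illustrated with a short snippet (the list simply mirrors the def_setting['DeformMethods'] example above):

```python
# Mirrors the DeformMethods list defined above.
deform_methods = (['respiratory_motion'] * 4 + ['single_frequency'] * 5
                  + ['mixed_frequency'] * 4 + ['zero'])

# 'ValidationDSmoothList': [2, 4, 8] selects entries by index:
selected = [deform_methods[i] for i in [2, 4, 8]]
# selected == ['respiratory_motion', 'single_frequency', 'single_frequency']
```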

2.3 Network

The proposed networks are shown in Figures 3, 4 and 5.

Figure 3: unet1 (U-Net).


Figure 4: decimation4 (Multi-view).


Figure 5: crop4 (U-Net advanced).

2.4 Software Architecture


Figure 6: Software Architecture.

2.4.1 Memory efficiency

It is not efficient (or even possible) to load all images with their DVFs into memory: stored as float32, a DVF is three times larger than its corresponding image. Instead, this software loads a chunk of images at a time. The number of images per chunk can be chosen with the parameter setting['NetworkTraining']['NumberOfImagesPerChunk']:

setting['NetworkTraining']['NumberOfImagesPerChunk'] = 16  # Number of images to load into RAM at a time
setting['NetworkTraining']['SamplesPerImage'] = 50
setting['NetworkTraining']['BatchSize'] = 15
setting['NetworkTraining']['MaxQueueSize'] = 20
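As a back-of-the-envelope check of why chunking is needed (the 256^3 volume size is an assumption for illustration, not a dataset property):

```python
import numpy as np

voxels = 256 ** 3                                   # assumed 256^3 volume
bytes_image = voxels * np.dtype(np.float32).itemsize
bytes_dvf = 3 * bytes_image                         # 3 float32 components/voxel
chunk_gib = 16 * (bytes_image + bytes_dvf) / 2 ** 30
print(chunk_gib)  # GiB needed for a 16-image chunk with DVFs
```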

2.4.2 Parallel Computing

We use threading to read patches in parallel with training the network. The class functions.reading.chunk_image.Images reads images using threads.
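The underlying producer-consumer pattern can be sketched as follows (load_chunk and the queue size are illustrative; the actual implementation lives in functions.reading.chunk_image.Images):

```python
import queue
import threading

def start_chunk_reader(load_chunk, chunk_ids, max_queue_size=20):
    """Producer thread that fills a bounded queue with chunks while the
    training loop consumes them. load_chunk is a hypothetical loader."""
    q = queue.Queue(maxsize=max_queue_size)

    def worker():
        for cid in chunk_ids:
            q.put(load_chunk(cid))  # blocks when the queue is full
        q.put(None)                 # sentinel: no more chunks

    threading.Thread(target=worker, daemon=True).start()
    return q

# Consumer side (the training loop):
chunks = start_chunk_reader(lambda cid: 'chunk-%d' % cid, [0, 1, 2])
loaded = []
while True:
    chunk = chunks.get()
    if chunk is None:
        break
    loaded.append(chunk)  # train on the chunk here
```

The bounded queue keeps reading ahead of training without letting the reader run arbitrarily far ahead of the consumer (compare setting['NetworkTraining']['MaxQueueSize'] above).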

regnet's People

Contributors

hsokooti


regnet's Issues

ImportError: No module named 'functions.RegNetModel.crop1_connection'

Traceback (most recent call last):
  File "/usr/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/markemus/dev/RegNet/RegNet3D.py", line 17, in <module>
    import functions.RegNetModel as RegNetModel
  File "/home/markemus/dev/RegNet/functions/RegNetModel/__init__.py", line 17, in <module>
    from .crop1_connection import crop1_connection
ImportError: No module named 'functions.RegNetModel.crop1_connection'

I got this error trying to run RegNet3D.py. Commenting out those imports fixed the error, and grep doesn't show any usage of them.

Code quality

Dear Hsokooti,

Thanks for creating this repo. I have read your paper and it looks interesting, so I would like to replicate it myself. However, I notice that (IMO) your code can be improved, both in readability and in reduction of redundancy. Since I want to help improve it, I would like to offer my suggestions, but I don't know how to post them. Could you help me with that?

Besides that, I have one question for now. In the script where you create the synthetic DVFs, there are options for blob() and smooth(). Both create a DVFb, but blob() additionally creates a DeformedArea surface.
As far as I can tell, the DVFb is used to transform one set of images into a deformed version; this fixed and moving image pair is then used to train the model.
However, I can't find where the DeformedArea object is used again. I believe you show an image of it in the README.md, but I don't see it transforming an original image the way DVFb does. Am I missing something?

SimpleITK ReadImage error

RuntimeError: Exception thrown in SimpleITK ReadImage: C:\Users\dashboard\Miniconda3\conda-bld\simpleitk_1521730316398\work\Code\IO\src\sitkImageReaderBase.cxx:89:
sitk::ERROR: The file "/srv/2-lkeb-16-reg1/hsokooti/DL/Elastix/LungExp/ExpLung11/Result/MovingImageFullRS1.mha" does not exist.

estimates the DVF from a pair of input images

Thanks for sharing the source code, which is super helpful for understanding the method. Based on my understanding, the network takes patches as input. May I ask how to estimate the DVF during testing given a pair of input images, ideally in a single shot?

DIR-lab 4DCT dataset

I'm trying to load the DIR-lab 4DCT dataset, but I've run into some trouble.

I downloaded the dataset from dir-lab.com, and applied the following pre-processing steps:

-downloaded all 10 data sets.
-renamed *-ssm.img to *_s.img (to silence errors from dirlab.py).
-ran the dirlab.py module on the img files.
-copied the resulting mha directory to /srv/mymachine/hsokooti/Data/DIR-Lab/4DCT/.

However, running RegNet3D.py throws the following error:

[MainThread  ] ---------------------------------Fri Nov  9 11:21:44 2018--------------------------------
[MainThread  ] ----------------------------------start experiment------------------------------
[MainThread  ] number of images in the last chunk=12
[MainThread  ] SyntheticDeformation[generation]: start DIR-Lab_4D/3D_max7_D9//T40/case2/Dsmooth0/DNext3/nextIm.mha
[MainThread  ] Process Process-1:
[MainThread  ] Traceback (most recent call last):
[MainThread  ]   File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
[MainThread  ]     self.run()
[MainThread  ]   File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
[MainThread  ]     self._target(*self._args, **self._kwargs)
[MainThread  ]   File "/home/markemus/dev/RegNet/functions/reading/direct_1st_epoch.py", line 49, in run
[MainThread  ]     self.fill()
[MainThread  ]   File "/home/markemus/dev/RegNet/functions/reading/direct_1st_epoch.py", line 85, in fill
[MainThread  ]     mode_synthetic_dvf='generation'
[MainThread  ]   File "/home/markemus/dev/RegNet/functions/synthetic_deformation.py", line 63, in get_dvf_and_deformed_images
[MainThread  ]     generate_next_im(setting, im_info=im_info)
[MainThread  ]   File "/home/markemus/dev/RegNet/functions/synthetic_deformation.py", line 495, in generate_next_im
[MainThread  ]     original_im_sitk = sitk.ReadImage(su.address_generator(setting, 'originalIm', **im_info_su))
[MainThread  ]   File "/home/markemus/.local/lib/python3.5/site-packages/SimpleITK/SimpleITK.py", line 8614, in ReadImage
[MainThread  ]     return _SimpleITK.ReadImage(*args)
[MainThread  ] RuntimeError: Exception thrown in SimpleITK ReadImage: /tmp/SimpleITK/Code/IO/src/sitkImageReaderBase.cxx:89:
[MainThread  ] sitk::ERROR: The file "/srv/markemus-VirtualBox/hsokooti/Data/DIR-Lab/4DCT/mha/case2/case2_T40_RS1.mha" does not exist.
[MainThread  ] total number of variables 1116164
[Thread-4    ] SyntheticDeformation[reading]: waiting 5s for DIR-Lab_4D/3D_max7_D9//T40/case2/Dsmooth0/DNext3/nextIm.mha

Is there an additional/different preprocessing step I need to do? Renaming */case2_T40.mha to */case2_T40_RS1.mha seems to fix this error and creates a nextIm.mha file, but there are more errors after that.

Performance issues in functions/registration/multi_stage.py(P2)

Hello, I found a performance issue in functions/registration/multi_stage.py:
sess = tf.Session() is repeatedly called inside for i_stage, stage in enumerate(...) and is never closed.
I think closing each session after using it will increase efficiency and avoid running out of memory.

Here are two files to support this issue: support1 and support2.

Looking forward to your reply. Btw, I am very glad to create a PR to fix it if you are too busy.
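The fix suggested here is the standard context-manager pattern: tf.Session supports the with statement (and close()), which releases the session when the stage finishes. A TensorFlow-free sketch of the pattern, with FakeSession as a stand-in so the snippet runs without TF installed:

```python
from contextlib import closing

class FakeSession:
    """Stand-in for tf.Session, which also supports `with` / close()."""
    def __init__(self):
        self.closed = False
    def close(self):
        self.closed = True

sessions = []
for i_stage in range(3):                 # one session per registration stage
    with closing(FakeSession()) as sess:
        sessions.append(sess)            # ...run the stage with sess here...
# Every session is closed once its stage finishes, releasing its resources.
```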

Loading data?

Hi @hsokooti I'm having trouble figuring out how to preprocess and load data into the net. I saw your post linking some datasets, but the model seems to load *.mha files and they all seem to be in other formats. Additionally, I'm not sure where to put the data in the directory structure; the documentation seems to be outdated and the setting["DLFolder"] key doesn't exist anymore.

patch concatenate

Hi, thank you; your code has helped me a lot.
However, I am confused about some of your ideas, especially about patches.

  1. From the function "next_sweep_patch()" in functions.reading.real_pair, it seems that you take patches from the image pair and then concatenate the per-patch predictions to form the DVF. But this easily generates wrong correspondences between adjacent patches, i.e. there are obvious seams between patches in the final DVF. Like this:
    [image]

How can this problem be solved?

  2. The input size in Function.RegNetModel.decimation3 is (155, 155, 155) and the output size is (27, 27, 27, 3). Is the output the DVF for the center (27, 27, 27) part of the (155, 155, 155) input?

Thanks a lot if you could answer my doubts.

Where can we download the SPREAD data (Stolk et al., 2007) used in your paper?

Dear @hsokooti and everyone,
I followed your paper, which uses three datasets:

  1. SPREAD (Stolk et al., 2007)
  2. DIR-Lab-4DCT
  3. DIR-Lab-COPD

From which website can we download the SPREAD data? I also followed this link (https://www.resmedjournal.com/article/S0954-6111(07)00186-2/fulltext), but I could not find out how to download SPREAD (Stolk et al., 2007). Please help me.
[image]

Datasets 2 and 3 I downloaded from https://www.dir-lab.com/Downloads.html; that works for me.
Thank you, @hsokooti and everyone.
