alanqrwang / keymorph
Robust multimodal image registration via keypoints
License: MIT License
There are no Simple_Unet or clean_mask classes in functions.model. Can you help me?
Do I need to convert the generated .npy files to .nii files to visualize them?
How is the brain_extraction_model.pth.tar file generated?
Hi, very nice work!
I'm trying to reproduce the preprocessing steps, but I can't seem to find the weights for the brain extraction model: '../weights/brain_extraction_model.pth.tar'. Are they in the repo? Would it be possible to include them in the release?
Thanks a lot!
Simon
I would like to know how to calculate the average absolute displacement using PyTorch's affine_grid and grid_sample functions.
Thank you
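A minimal sketch of one way to compute this, written in NumPy rather than PyTorch so it runs standalone. It rebuilds the normalized sampling grid that affine_grid produces with align_corners=True (np.linspace(-1, 1, S) along each axis), applies the 3x4 theta, and averages the displacement magnitude after rescaling to voxel units. The function name and the choice of Euclidean magnitude (rather than per-component absolute values) are assumptions.

```python
import numpy as np

def mean_abs_displacement(theta, size_xyz):
    """Mean displacement magnitude, in voxels, of a 3D affine transform.

    theta: (3, 4) matrix in PyTorch normalized coordinates (what affine_grid takes).
    size_xyz: voxel counts (W, H, D), i.e. the x, y, z axes of the grid.
    Assumes align_corners=True, where normalized u maps to voxel (u + 1) * (S - 1) / 2.
    """
    theta = np.asarray(theta, dtype=float)
    size = np.asarray(size_xyz, dtype=float)
    # Identity grid of normalized coordinates, one (x, y, z) row per voxel
    axes = [np.linspace(-1.0, 1.0, int(s)) for s in size]
    xs, ys, zs = np.meshgrid(*axes, indexing="ij")
    pts = np.stack([xs.ravel(), ys.ravel(), zs.ravel()], axis=1)
    # Sampling locations produced by the affine: A @ p + b
    warped = pts @ theta[:, :3].T + theta[:, 3]
    # Displacement in normalized units, rescaled per axis to voxels
    disp_vox = (warped - pts) * (size - 1.0) / 2.0
    return float(np.linalg.norm(disp_vox, axis=1).mean())
```

With PyTorch, the same quantity can be obtained by subtracting the grid returned by F.affine_grid for an identity theta from the grid for your theta and applying the same per-axis (S - 1) / 2 scaling; the last dimension of that grid is ordered (x, y, z) = (W, H, D).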
Maybe there is a versioning issue here? ConvNetFC is in keymorph.net, and the module keymorph.step does not exist, yet the import list in register.py contains:
from keymorph.model import ConvNetFC, ConvNetCoM
from keymorph.step import step
Hi again,
Thank you very much for publishing this wonderful work!
I am trying to compare your methods and am struggling to determine which of KeyMorph and BrainMorph performs best. After reading your MedIA paper and the BrainMorph arXiv paper, I would expect BrainMorph to perform better, since it is trained at full resolution with a larger model, whereas KeyMorph is trained at half resolution. However, I didn't see a direct comparison in the papers. So my quick question is: is BrainMorph expected to outperform KeyMorph?
Hello! Thanks for your brilliant work!
I have some questions about the process of centroid calculation:
mx = vol.sum(dim=(2,3))
Mx = mx.sum(-1, True) + eps
my = vol.sum(dim=(2,4))
My = my.sum(-1, True) + eps
mz = vol.sum(dim=(3,4))
Mz = mz.sum(-1, True) + eps
Does this mean that the input has shape [n_batch, channels, dim_Z, dim_Y, dim_X]?
And what would be the difference if I changed it as follows:
mx = vol.sum(dim=(3,4))
Mx = mx.sum(-1, True) + eps
my = vol.sum(dim=(2,4))
My = my.sum(-1, True) + eps
mz = vol.sum(dim=(2,3))
Mz = mz.sum(-1, True) + eps
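Assuming the input is laid out as [N, C, D, H, W] (the PyTorch 3D convention, which matches dim_Z, dim_Y, dim_X above), summing over dims (2, 3) leaves only the W axis, so mx is the marginal along x. The NumPy sketch below mirrors the original snippet and then, as an assumption about the rest of the function, weights each marginal by normalized coordinates in [-1, 1] to get a soft centroid ordered (x, y, z). With the proposed change, mx would instead be the marginal along D, so the centroid would come out in (z, y, x) order: the same point with the x and z labels exchanged, which no longer matches the (x, y, z) ordering that affine_grid and grid_sample use.

```python
import numpy as np

def centers_of_mass(vol, eps=1e-8):
    """Soft centroid per channel for a non-negative vol of shape (N, C, D, H, W).

    Returns an (N, C, 3) array ordered (x, y, z) = (W, H, D), with each
    coordinate in [-1, 1] (align_corners=True convention assumed).
    """
    _, _, d, h, w = vol.shape
    mx = vol.sum(axis=(2, 3))  # (N, C, W): marginal along x
    my = vol.sum(axis=(2, 4))  # (N, C, H): marginal along y
    mz = vol.sum(axis=(3, 4))  # (N, C, D): marginal along z
    Mx = mx.sum(-1, keepdims=True) + eps
    My = my.sum(-1, keepdims=True) + eps
    Mz = mz.sum(-1, keepdims=True) + eps
    # Normalized coordinate of each index along W, H and D
    gx, gy, gz = (np.linspace(-1.0, 1.0, s) for s in (w, h, d))
    x = (mx * gx).sum(-1, keepdims=True) / Mx
    y = (my * gy).sum(-1, keepdims=True) / My
    z = (mz * gz).sum(-1, keepdims=True) / Mz
    return np.concatenate([x, y, z], axis=-1)
```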
Hi team
I'm using the registration code to register the example data, following the instructions:
python register.py \
  --moving ./example_data/images/IXI_001.nii.gz \
  --fixed ./example_data/images/IXI_002.nii.gz \
  --load_path ./weights/numkey512_tps0_dice.4760.h5 \
  --num_keypoints 512 \
  --moving_seg ./example_data/labels/IXI_001.nii.gz \
  --fixed_seg ./example_data/labels/IXI_002.nii.gz
I've also enabled --save_preds, but I still cannot find the saved registered image. Is there an output folder?
Best,
Zeyu
AttributeError: module 'torch' has no attribute 'amp'
How can this problem be solved?
Hello, where can I find brain_extraction_model.pth.tar? 😃
Hi,
Thank you for releasing your code -- great work! 🤩
I'm trying to do deformable registration by training on the OASIS dataset. I've compiled my CSV file by taking all possible pairs from the training set. The first few rows look as follows:
fixed_img_path,moving_img_path,fixed_seg_path,moving_seg_path,fixed_mask_path,moving_mask_path,train
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0002_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0002_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0003_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0003_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0004_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0004_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0005_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0005_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0006_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0006_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0007_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0007_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0009_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0009_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0010_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0010_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0011_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0011_MR1/aligned_seg35.nii.gz,None,None,True
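For reference, a pair file like the one above can be generated with the Python standard library. This is a minimal sketch: the function name is hypothetical, the columns are copied from the header above, and itertools.combinations yields each unordered pair once with the earlier subject as fixed, as in the rows shown.

```python
import csv
import itertools

HEADER = ["fixed_img_path", "moving_img_path", "fixed_seg_path",
          "moving_seg_path", "fixed_mask_path", "moving_mask_path", "train"]

def write_pairs_csv(subject_dirs, out_path, train=True):
    """Write one row per unordered pair of subject directories."""
    with open(out_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(HEADER)
        for fixed, moving in itertools.combinations(subject_dirs, 2):
            writer.writerow([
                f"{fixed}/aligned_norm.nii.gz",
                f"{moving}/aligned_norm.nii.gz",
                f"{fixed}/aligned_seg35.nii.gz",
                f"{moving}/aligned_seg35.nii.gz",
                "None",  # no brain masks, matching the rows above
                "None",
                str(train),
            ])
```

Swapping in itertools.permutations would also add the reverse direction of every pair, doubling the number of rows.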
I'm using the following script to train:
python scripts/run.py --job_name oasis_seg --save_dir ./oasis-run-seg --num_keypoints 512 --loss_fn mse --transform_type tps_0 --data_path ./train_oasis_seg.csv --train_dataset csv --run_mode train --backbone truncatedunet --use_amp
But I get a Dice score of only around 0.65 on the validation set, which is not very good. I've verified that there are 36 labeled classes in the segmentation (the first channel is background and is ignored).
Training with the Dice loss (--loss_fn dice) does not help either.
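One way to sanity-check the reported score offline is to compute Dice per label yourself on the warped and fixed segmentations. A minimal NumPy sketch, assuming integer label maps where 0 is background and labels 1 to 35 are the foreground structures; whether KeyMorph's own evaluation skips labels absent from both volumes, as this sketch does, is an assumption worth checking, since that choice alone can shift the mean.

```python
import numpy as np

def mean_dice(seg_a, seg_b, labels):
    """Mean Dice over the given label values of two integer label maps."""
    scores = []
    for lab in labels:
        a = seg_a == lab
        b = seg_b == lab
        denom = a.sum() + b.sum()
        if denom == 0:
            continue  # label absent from both maps: skip instead of scoring 1.0
        scores.append(2.0 * np.logical_and(a, b).sum() / denom)
    return float(np.mean(scores)) if scores else float("nan")
```

For example, mean_dice(warped_seg, fixed_seg, range(1, 36)) averages over the 35 foreground labels; warped_seg and fixed_seg are placeholder names.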
Have you tried training on the OASIS dataset, and did you see different results? Sharing the training scripts or pretrained models would be immensely useful.
Let me know if I'm missing something. Thanks again!
Hello! Thank you for your work.
When I run the command to extract weights03 (cat weights03 | tar xzpvf -), it fails.
It reports:
gzip: stdin: not in gzip format
tar: Child died with signal 13
tar: Error is not recoverable: exiting now
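The message "not in gzip format" means the bytes piped into tar do not begin with the two-byte gzip signature (0x1f 0x8b). One possible cause, assumed here only from the numeric suffix, is that weights03 is a single piece of an archive split into several files. In that case only the first piece carries the signature, and all pieces must be concatenated in order before extraction. A minimal Python sketch of both the check and the reassembly; the helper names and part file names are hypothetical.

```python
import shutil
import tarfile

GZIP_MAGIC = b"\x1f\x8b"

def is_gzip(path):
    """Return True if the file starts with the gzip signature."""
    with open(path, "rb") as f:
        return f.read(2) == GZIP_MAGIC

def reassemble_and_extract(parts, archive_path, dest):
    """Concatenate split parts in the given order, then extract the tar.gz."""
    if not is_gzip(parts[0]):
        raise ValueError(f"{parts[0]} does not start with a gzip header")
    with open(archive_path, "wb") as out:
        for part in parts:
            with open(part, "rb") as src:
                shutil.copyfileobj(src, out)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest)
```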
Thank you for publishing this code! run.py fails on my system with the following error:
(venv) tgreer@biag-w05:/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph$ python run.py --kp_align_method affine --num_keypoints 128 --loss_fn mse --eval --load_path ./weights/numkey128_aff_dice.1560.h5
{'affine_slope': -1,
'batch_size': 1,
'data_dir': './data/centered_IXI/',
'dataset': 'ixi',
'debug_mode': False,
'dim': 3,
'epochs': 2000,
'eval': True,
'gpus': '0',
'job_name': 'keymorph',
'kp_align_method': 'affine',
'kp_extractor': 'conv_com',
'kpconsistency_coeff': 0,
'load_path': './weights/numkey128_aff_dice.1560.h5',
'log_interval': 25,
'loss_fn': 'mse',
'lr': 3e-06,
'mix_modalities': False,
'norm_type': 'instance',
'num_keypoints': 128,
'num_test_subjects': 100,
'num_workers': 1,
'resume': False,
'save_dir': './output/',
'save_preds': False,
'seed': 23,
'steps_per_epoch': 32,
'tps_lmbda': None,
'transform': 'none',
'use_amp': False,
'use_wandb': False,
'visualize': False,
'wandb_api_key_path': None,
'wandb_kwargs': {},
'weighted_kp_align': False}
Number of GPUs: 2
Fixed train dataset has 3 modalities.
-> Modality T1 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality T2 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality PD has 427 subjects (427 images, 427 masks and 0 segmentations)
Moving train dataset has 3 modalities.
-> Modality T1 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality T2 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality PD has 427 subjects (427 images, 427 masks and 0 segmentations)
Test dataset has 3 modalities.
-> Modality T1 has 100 subjects (100 images, 100 masks and 0 segmentations)
-> Modality T2 has 100 subjects (100 images, 100 masks and 0 segmentations)
-> Modality PD has 100 subjects (100 images, 100 masks and 0 segmentations)
/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:32: UserWarning:
There is an imbalance between your GPUs. You may want to exclude GPU 1 which
has less than 75% of the memory or cores of GPU 0. You can do so by setting
the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
environment variable.
warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
Model Summary
---------------------------------------------------------------
module.keypoint_extractor.module.block1.conv.weight
module.keypoint_extractor.module.block1.conv.bias
module.keypoint_extractor.module.block2.conv.weight
module.keypoint_extractor.module.block2.conv.bias
module.keypoint_extractor.module.block3.conv.weight
module.keypoint_extractor.module.block3.conv.bias
module.keypoint_extractor.module.block4.conv.weight
module.keypoint_extractor.module.block4.conv.bias
module.keypoint_extractor.module.block5.conv.weight
module.keypoint_extractor.module.block5.conv.bias
module.keypoint_extractor.module.block6.conv.weight
module.keypoint_extractor.module.block6.conv.bias
module.keypoint_extractor.module.block7.conv.weight
module.keypoint_extractor.module.block7.conv.bias
module.keypoint_extractor.module.block8.conv.weight
module.keypoint_extractor.module.block8.conv.bias
module.keypoint_extractor.module.block9.conv.weight
module.keypoint_extractor.module.block9.conv.bias
Total parameters: 8794496
---------------------------------------------------------------
Running test: subject id 0->0, mod T1->T1, aug rot0
Traceback (most recent call last):
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/run.py", line 771, in <module>
main()
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/run.py", line 616, in main
grid, points_f, points_m = registration_model(
^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
output.reraise()
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/_utils.py", line 644, in reraise
raise exception
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
output = module(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/model.py", line 45, in forward
points_f, points_m = self.extract_keypoints_step(img_f, img_m)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/model.py", line 69, in extract_keypoints_step
return self.keypoint_extractor(img1), self.keypoint_extractor(img2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
output.reraise()
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/_utils.py", line 644, in reraise
raise exception
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
output = module(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/net.py", line 85, in forward
out = self.block1(x)
^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/layers.py", line 128, in forward
out = self.conv(x)
^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 613, in forward
return self._conv_forward(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 608, in _conv_forward
return F.conv3d(
^^^^^^^^^
TypeError: conv3d() received an invalid combination of arguments - got (tuple, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
* (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
didn't match because some of the arguments have invalid types: (tuple of (Tensor,), Parameter, Parameter, tuple of (int, int, int), tuple of (int, int, int), tuple of (int, int, int), int)
* (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
didn't match because some of the arguments have invalid types: (tuple of (Tensor,), Parameter, Parameter, tuple of (int, int, int), tuple of (int, int, int), tuple of (int, int, int), int)
Do you have any advice for proceeding?
Hi,
Thank you for publishing this wonderful work.
I am having difficulty extracting the transformation matrix in the original voxel space from the model. I need this matrix to warp the surface data associated with the volume using FreeSurfer.
From my understanding, the model expects inputs approximately in RAS space, handled by torchio.ToCanonical, and then aligns keypoints in normalized space. This suggests that the final output matrix is in normalized space, not in the original voxel space.
Could you please guide me on how to obtain the 4x4 affine and rigid matrices in the original voxel space of my inputs?
Thank you very much for your help.
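A hedged sketch of the change of coordinates, assuming the model's affine is available as the 3x4 theta accepted by torch.nn.functional.affine_grid (so it maps fixed-image normalized coordinates to the moving-image coordinates that grid_sample reads from) and that the tensor axes (D, H, W) are the NIfTI array axes (i, j, k). If N_f and N_m map fixed and moving voxel indices to normalized coordinates, the voxel-space matrix is inv(N_m) @ T @ N_f, and the 4x4 voxel-to-world affines then give the scanner-space version. The function names are hypothetical, and the voxel-to-world affines passed in must be those of the images after torchio.ToCanonical; otherwise the result is in the wrong frame. The result maps fixed to moving, so warping surfaces in the other direction needs its inverse.

```python
import numpy as np

def _norm_matrix(size_xyz, align_corners):
    """4x4 map from voxel indices (x, y, z) to PyTorch normalized coords."""
    size = np.asarray(size_xyz, dtype=float)
    if align_corners:
        scale, offset = 2.0 / (size - 1.0), -np.ones(3)
    else:
        # align_corners=False: -1 and 1 are the outer edges of the corner voxels
        scale, offset = 2.0 / size, 1.0 / size - 1.0
    N = np.eye(4)
    N[:3, :3] = np.diag(scale)
    N[:3, 3] = offset
    return N

def theta_to_voxel(theta, fixed_size_xyz, moving_size_xyz, align_corners=True):
    """Convert a (3, 4) affine_grid theta into a 4x4 voxel-space matrix.

    Maps fixed voxel indices (x, y, z) = (W, H, D) to the moving voxel
    indices that grid_sample would read from.
    """
    T = np.eye(4)
    T[:3, :] = np.asarray(theta, dtype=float)
    N_f = _norm_matrix(fixed_size_xyz, align_corners)
    N_m = _norm_matrix(moving_size_xyz, align_corners)
    return np.linalg.inv(N_m) @ T @ N_f

def voxel_to_world(M_xyz, fixed_vox2world, moving_vox2world):
    """Express the voxel matrix in world (scanner) coordinates.

    Assumes tensor axes (D, H, W) are NIfTI array axes (i, j, k), so the
    (x, y, z) order of affine_grid is (k, j, i) and must be reversed first.
    """
    P = np.eye(4)[[2, 1, 0, 3]]  # swaps x and z; P is its own inverse
    M_ijk = P @ M_xyz @ P
    return moving_vox2world @ M_ijk @ np.linalg.inv(fixed_vox2world)
```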
How do I get a file with the .tar.gz suffix? The files in the data folder do not have a suffix.
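Python's tarfile module can detect compression from the file contents rather than from the name, so a missing .tar.gz suffix does not by itself prevent reading the archive. A minimal sketch using tarfile.open with mode "r:*"; the function name is hypothetical.

```python
import tarfile

def list_archive(path):
    """List member names of a (possibly compressed) tar archive, whatever its name."""
    # "r:*" lets tarfile detect gzip/bz2/xz compression from the file contents
    with tarfile.open(path, "r:*") as tar:
        return tar.getnames()
```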