
keymorph's People

Contributors

alanqrwang, evanmy

keymorph's Issues

Weights for brain extraction

Hi, very nice work!

I'm trying to reproduce the preprocessing steps, but I can't find the weights for the brain extraction model: '../weights/brain_extraction_model.pth.tar'. Are they in the repo? Would it be possible to include them in the release?

Thanks a lot!
Simon

register.py not running

Maybe there is a versioning issue here? ConvNetFC is in keymorph.net and the module keymorph.step does not exist.

The import list in register.py:

from keymorph.model import ConvNetFC, ConvNetCoM
from keymorph.step import step
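One way to check which layout the installed version actually has is to probe the candidate modules directly. The sketch below assumes only the module names mentioned above; `find_symbol` is a hypothetical helper, not part of keymorph.

```python
import importlib
import importlib.util


def find_symbol(name, candidate_modules):
    """Return the first module in `candidate_modules` that exists and defines `name`.

    Returns None if no candidate provides it. Missing modules are skipped
    instead of raising, so several package layouts can be probed safely.
    """
    for mod_name in candidate_modules:
        try:
            spec = importlib.util.find_spec(mod_name)
        except ModuleNotFoundError:
            # Raised when a parent package (e.g. "keymorph") is not installed
            continue
        if spec is None:
            # Parent package exists, but this submodule does not
            continue
        module = importlib.import_module(mod_name)
        if hasattr(module, name):
            return mod_name
    return None
```

For example, `find_symbol("ConvNetFC", ["keymorph.model", "keymorph.net"])` would return `"keymorph.net"` if the report above is accurate, and `find_symbol("step", ["keymorph.step"])` returning `None` would confirm that `keymorph.step` is missing from the installed version.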

Is Brainmorph expected to be better than Keymorph?

Hi again,

Thank you very much for publishing this wonderful work!

I am trying to compare your methods and am struggling to determine which variant performs best, KeyMorph or BrainMorph. After reading your MedIA paper and the BrainMorph arXiv paper, I would expect BrainMorph to perform better, since it is trained at full resolution with a larger model, whereas KeyMorph is trained at half resolution. However, I didn't see a direct comparison in the papers. So my quick question is: is BrainMorph expected to be better than KeyMorph?

Details about functions.layers.CenterOfMass3d()

Hello! Thanks for your brilliant work!
I have some questions about the process of centroid calculation:

mx = vol.sum(dim=(2,3))
Mx = mx.sum(-1, True) + eps
my = vol.sum(dim=(2,4))
My = my.sum(-1, True) + eps
mz = vol.sum(dim=(3,4))
Mz = mz.sum(-1, True) + eps

Does this mean that the input shape is [n_batch, channels, dim_Z, dim_Y, dim_X]?
And what difference would it make if I changed it as follows:

mx = vol.sum(dim=(3,4))
Mx = mx.sum(-1, True) + eps
my = vol.sum(dim=(2,4))
My = my.sum(-1, True) + eps
mz = vol.sum(dim=(2,3))
Mz = mz.sum(-1, True) + eps
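In the original snippet, `vol.sum(dim=(2,3))` sums out axes 2 and 3, so `mx` is the marginal along the last axis (axis 4); likewise `mz` is the marginal along axis 2. That matches the [n_batch, channels, dim_Z, dim_Y, dim_X] reading. The proposed change computes the same three marginals but swaps which one is called x and which z, so it only reverses the coordinate order. The sketch below illustrates this with NumPy standing in for PyTorch (the `sum` semantics are the same). It assumes raw voxel indices as coordinates, whereas the real `CenterOfMass3d` layer may use a normalized grid; `axis_center_of_mass` and `center_of_mass_3d` are hypothetical names.

```python
import numpy as np


def axis_center_of_mass(vol, axis, eps=1e-8):
    """Soft center of mass of `vol` ([N, C, A2, A3, A4]) along one spatial axis.

    Summing over the two other spatial axes leaves the marginal along `axis`.
    Coordinates are raw voxel indices (an assumption, not the layer's grid).
    """
    others = tuple(a for a in (2, 3, 4) if a != axis)
    m = vol.sum(axis=others)                  # [N, C, size of `axis`]
    M = m.sum(-1, keepdims=True) + eps        # total mass per channel
    idx = np.arange(vol.shape[axis])          # voxel index along `axis`
    return (m * idx).sum(-1, keepdims=True) / M


def center_of_mass_3d(vol, order=(4, 3, 2)):
    """Stack per-axis centers into [N, C, 3], one entry per axis in `order`."""
    return np.concatenate([axis_center_of_mass(vol, a) for a in order], axis=-1)


# One unit of mass at array index (axis2=1, axis3=2, axis4=3)
vol = np.zeros((1, 1, 4, 5, 6))
vol[0, 0, 1, 2, 3] = 1.0

original = center_of_mass_3d(vol, order=(4, 3, 2))  # x from axis 4, z from axis 2
swapped = center_of_mass_3d(vol, order=(2, 3, 4))   # the proposed change
# original[0, 0] is approximately [3, 2, 1]; swapped[0, 0] approximately [1, 2, 3]
```

Whether the swap matters depends on how the keypoints are consumed downstream. If they are passed to `torch.nn.functional.grid_sample`, whose 5D grid stores coordinates as (x, y, z) indexing (W, H, D), i.e. the last, middle and first spatial axes, then the original ordering is the one that matches, and the swapped version would need its output reversed.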

Cannot find the saved registered image

Hi team

I'm using the registration script to register the example data, following the instructions:

python register.py \
    --moving ./example_data/images/IXI_001.nii.gz \
    --fixed ./example_data/images/IXI_002.nii.gz \
    --load_path ./weights/numkey512_tps0_dice.4760.h5 \
    --num_keypoints 512 \
    --moving_seg ./example_data/labels/IXI_001.nii.gz \
    --fixed_seg ./example_data/labels/IXI_002.nii.gz

I also tried adding --save_preds, but I still cannot find the saved registered image. Is there an output folder?

Best,
Zeyu
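While waiting for an answer, one generic way to see whether a run wrote anything at all is to list image-like files modified after the run started. This is a debugging sketch only: `find_recent_outputs` is a hypothetical helper, and the suffix list is a guess, since the output format of `register.py --save_preds` is not stated here.

```python
import time
from pathlib import Path


def find_recent_outputs(root=".", since=None, suffixes=(".nii.gz", ".nii", ".npy")):
    """List files under `root` ending in one of `suffixes`, modified at or after `since`.

    `since` is a POSIX timestamp (e.g. time.time() taken just before running
    register.py); None disables the time filter. Results are sorted newest first.
    """
    hits = []
    for path in Path(root).rglob("*"):
        if not path.is_file():
            continue
        if not any(path.name.endswith(s) for s in suffixes):
            continue
        mtime = path.stat().st_mtime
        if since is None or mtime >= since:
            hits.append((mtime, path))
    return [p for _, p in sorted(hits, key=lambda t: t[0], reverse=True)]
```

Typical use: record `t0 = time.time()`, run the registration command, then call `find_recent_outputs(".", since=t0)` from the repository root.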

Results on OASIS data

Hi,

Thank you for releasing your code -- great work! 🤩

I'm trying to do deformable registration by training on the OASIS dataset. I've compiled my CSV file by taking all possible pairs from the training set. The first few rows look as follows:

fixed_img_path,moving_img_path,fixed_seg_path,moving_seg_path,fixed_mask_path,moving_mask_path,train
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0002_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0002_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0003_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0003_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0004_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0004_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0005_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0005_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0006_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0006_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0007_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0007_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0009_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0009_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0010_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0010_MR1/aligned_seg35.nii.gz,None,None,True
/data/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0011_MR1/aligned_norm.nii.gz,/data/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz,/data/OASIS_OAS1_0011_MR1/aligned_seg35.nii.gz,None,None,True
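A pair list in this format can be generated with a few lines of Python. This is a minimal sketch that assumes the per-subject file names seen in the rows above (aligned_norm.nii.gz, aligned_seg35.nii.gz) and treats "all possible pairs" as ordered pairs of distinct subjects; `write_pair_csv` is a hypothetical helper, not part of keymorph.

```python
import csv
import itertools

COLUMNS = [
    "fixed_img_path", "moving_img_path",
    "fixed_seg_path", "moving_seg_path",
    "fixed_mask_path", "moving_mask_path",
    "train",
]


def write_pair_csv(subject_dirs, out_file, train=True):
    """Write a header plus one row per ordered (fixed, moving) pair of distinct subjects.

    Masks are written as the literal string "None", as in the rows above.
    Returns the number of data rows written.
    """
    writer = csv.writer(out_file)
    writer.writerow(COLUMNS)
    n_rows = 0
    for fixed, moving in itertools.permutations(subject_dirs, 2):
        writer.writerow([
            f"{fixed}/aligned_norm.nii.gz", f"{moving}/aligned_norm.nii.gz",
            f"{fixed}/aligned_seg35.nii.gz", f"{moving}/aligned_seg35.nii.gz",
            "None", "None", train,
        ])
        n_rows += 1
    return n_rows
```

With n subjects this produces n(n - 1) rows, so the file grows quadratically; `itertools.combinations` would halve it if pair direction does not matter for training.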

I'm using the following script to train:

python scripts/run.py --job_name oasis_seg --save_dir ./oasis-run-seg --num_keypoints 512 --loss_fn mse --transform_type tps_0 --data_path ./train_oasis_seg.csv --train_dataset csv --run_mode train --backbone truncatedunet --use_amp

However, the validation Dice score is only around 0.65, which is not very good. I've verified that the segmentation has 36 label classes (the first channel is background and is ignored).

Training with the dice loss (--loss_fn dice) does not help either.
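When comparing Dice numbers across setups, small differences in how the metric is computed (which labels are included, how labels missing from both images are handled) can shift the mean noticeably. Below is a minimal sketch of mean per-label Dice on integer label maps with background excluded. It is not keymorph's own metric code; `mean_dice` is a hypothetical name, and `labels=range(1, 36)` would cover 35 foreground labels.

```python
import numpy as np


def mean_dice(seg_a, seg_b, labels, eps=1e-8):
    """Mean Dice over `labels` for two integer label maps of the same shape.

    Dice = 2|A and B| / (|A| + |B|) per label. Labels absent from both maps
    are skipped rather than scored; background is excluded by leaving label 0
    out of `labels`. Returns NaN if no label is present in either map.
    """
    scores = []
    for lab in labels:
        a = seg_a == lab
        b = seg_b == lab
        denom = a.sum() + b.sum()
        if denom == 0:
            continue
        scores.append(2.0 * np.logical_and(a, b).sum() / (denom + eps))
    return float(np.mean(scores)) if scores else float("nan")
```

Skipping labels that are absent from both maps is a design choice; counting them as 1.0 (or 0.0) would raise (or lower) the mean when some structures are missing.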

Have you tried training on the OASIS dataset, and did you see different results? Sharing the training scripts or pretrained models would be immensely useful.

Let me know if I'm missing something. Thanks again!

How to get the weights of brain extractor

Hello, and thank you for your work!
When I run the command to extract weights03 (cat weights03 | tar xzpvf -), it fails with:
gzip: stdin: not in gzip format
tar: Child died with signal 13
tar: Error is not recoverable: exiting now
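"gzip: stdin: not in gzip format" means the bytes piped into tar do not start with a gzip header (the magic bytes 0x1f 0x8b), and "signal 13" is SIGPIPE, a side effect of gzip exiting early. One possible cause, which is an assumption not confirmed by the repository, is that weights03 is only one piece of a split archive, in which case only the first piece carries the gzip header and the pieces must be concatenated in order (in shell, something like cat weights0* | tar xzpvf -). The sketch below checks the header and performs the concatenation in Python; `is_gzip` and `extract_split_archive` are hypothetical helpers.

```python
import tarfile
from pathlib import Path

GZIP_MAGIC = b"\x1f\x8b"  # first two bytes of every gzip stream


def is_gzip(path):
    """True if the file starts with the gzip magic bytes."""
    with open(path, "rb") as f:
        return f.read(2) == GZIP_MAGIC


def extract_split_archive(part_paths, dest):
    """Concatenate parts in the given order, then extract the resulting .tar.gz into `dest`.

    Assumes the parts are consecutive byte ranges of one gzip-compressed tar,
    as produced by a tool like `split`. Returns the path of the combined file.
    """
    combined = Path(dest) / "combined.tar.gz"
    with open(combined, "wb") as out:
        for part in part_paths:
            out.write(Path(part).read_bytes())
    with tarfile.open(combined, mode="r:gz") as tar:
        tar.extractall(dest)
    return combined
```

If `is_gzip("weights03")` is False but `is_gzip("weights00")` is True, that would support the split-archive explanation; the part order passed to `extract_split_archive` must match the original split order.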

Evaluation with run.py fails

Thank you for publishing this code! run.py fails on my system with the following error:

(venv) tgreer@biag-w05:/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph$ python run.py --kp_align_method affine --num_keypoints 128 --loss_fn mse --eval --load_path ./weights/numkey128_aff_dice.1560.h5
{'affine_slope': -1,
 'batch_size': 1,
 'data_dir': './data/centered_IXI/',
 'dataset': 'ixi',
 'debug_mode': False,
 'dim': 3,
 'epochs': 2000,
 'eval': True,
 'gpus': '0',
 'job_name': 'keymorph',
 'kp_align_method': 'affine',
 'kp_extractor': 'conv_com',
 'kpconsistency_coeff': 0,
 'load_path': './weights/numkey128_aff_dice.1560.h5',
 'log_interval': 25,
 'loss_fn': 'mse',
 'lr': 3e-06,
 'mix_modalities': False,
 'norm_type': 'instance',
 'num_keypoints': 128,
 'num_test_subjects': 100,
 'num_workers': 1,
 'resume': False,
 'save_dir': './output/',
 'save_preds': False,
 'seed': 23,
 'steps_per_epoch': 32,
 'tps_lmbda': None,
 'transform': 'none',
 'use_amp': False,
 'use_wandb': False,
 'visualize': False,
 'wandb_api_key_path': None,
 'wandb_kwargs': {},
 'weighted_kp_align': False}
Number of GPUs: 2
Fixed train dataset has 3 modalities.
-> Modality T1 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality T2 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality PD has 427 subjects (427 images, 427 masks and 0 segmentations)
Moving train dataset has 3 modalities.
-> Modality T1 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality T2 has 427 subjects (427 images, 427 masks and 0 segmentations)
-> Modality PD has 427 subjects (427 images, 427 masks and 0 segmentations)
Test dataset has 3 modalities.
-> Modality T1 has 100 subjects (100 images, 100 masks and 0 segmentations)
-> Modality T2 has 100 subjects (100 images, 100 masks and 0 segmentations)
-> Modality PD has 100 subjects (100 images, 100 masks and 0 segmentations)
/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:32: UserWarning:
    There is an imbalance between your GPUs. You may want to exclude GPU 1 which
    has less than 75% of the memory or cores of GPU 0. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.
  warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))

Model Summary
---------------------------------------------------------------
module.keypoint_extractor.module.block1.conv.weight
module.keypoint_extractor.module.block1.conv.bias
module.keypoint_extractor.module.block2.conv.weight
module.keypoint_extractor.module.block2.conv.bias
module.keypoint_extractor.module.block3.conv.weight
module.keypoint_extractor.module.block3.conv.bias
module.keypoint_extractor.module.block4.conv.weight
module.keypoint_extractor.module.block4.conv.bias
module.keypoint_extractor.module.block5.conv.weight
module.keypoint_extractor.module.block5.conv.bias
module.keypoint_extractor.module.block6.conv.weight
module.keypoint_extractor.module.block6.conv.bias
module.keypoint_extractor.module.block7.conv.weight
module.keypoint_extractor.module.block7.conv.bias
module.keypoint_extractor.module.block8.conv.weight
module.keypoint_extractor.module.block8.conv.bias
module.keypoint_extractor.module.block9.conv.weight
module.keypoint_extractor.module.block9.conv.bias
Total parameters: 8794496
---------------------------------------------------------------

Running test: subject id 0->0, mod T1->T1, aug rot0
Traceback (most recent call last):
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/run.py", line 771, in <module>
    main()
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/run.py", line 616, in main
    grid, points_f, points_m = registration_model(
                               ^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/model.py", line 45, in forward
    points_f, points_m = self.extract_keypoints_step(img_f, img_m)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/model.py", line 69, in extract_keypoints_step
    return self.keypoint_extractor(img1), self.keypoint_extractor(img2)
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/net.py", line 85, in forward
    out = self.block1(x)
          ^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/keymorph/layers.py", line 128, in forward
    out = self.conv(x)
          ^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/playpen-raid1/tgreer/equivariant_reg_2/keymorph_comparison/keymorph/venv/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 608, in _conv_forward
    return F.conv3d(
           ^^^^^^^^^
TypeError: conv3d() received an invalid combination of arguments - got (tuple, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (tuple of (Tensor,), Parameter, Parameter, tuple of (int, int, int), tuple of (int, int, int), tuple of (int, int, int), int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (tuple of (Tensor,), Parameter, Parameter, tuple of (int, int, int), tuple of (int, int, int), tuple of (int, int, int), int)

Do you have any advice for proceeding?
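The parameter names in the Model Summary (module.keypoint_extractor.module.block1...) suggest that the keypoint extractor is wrapped in nn.DataParallel and then wrapped again as part of the full registration model, and the traceback passes through DataParallel.forward twice before conv3d receives a tuple instead of a tensor. Assuming that nesting is the cause (not confirmed here), two workarounds may be worth trying: expose a single GPU (for example CUDA_VISIBLE_DEVICES=0, as the UserWarning itself suggests), since nn.DataParallel with a single device calls the wrapped module directly; or strip the nested wrappers. A minimal sketch of the latter; `unwrap_data_parallel` is a hypothetical helper, not part of keymorph.

```python
import torch.nn as nn


def unwrap_data_parallel(module: nn.Module) -> nn.Module:
    """Remove nn.DataParallel wrappers at the top level and at every depth.

    Unwrapping drops the "module." prefixes from parameter names, so apply it
    after loading a checkpoint whose keys still carry those prefixes.
    """
    while isinstance(module, nn.DataParallel):
        module = module.module
    for name, child in list(module.named_children()):
        # nn.Module.__setattr__ re-registers the unwrapped child under the same name
        setattr(module, name, unwrap_data_parallel(child))
    return module
```

In run.py this would amount to something like `registration_model = unwrap_data_parallel(registration_model)` placed after the weights are loaded and before evaluation; the model then runs on a single device.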

Extracting Original Voxel Space Transformation Matrix from Model Output

Hi,

Thank you for publishing this wonderful work.

I am having difficulty extracting the transformation matrix in the original voxel space from the model. I need this matrix to warp surface data associated with the volume using FreeSurfer.

As I understand it, the model expects inputs approximately in RAS orientation, handled by torchio.ToCanonical, and then aligns keypoints in a normalized coordinate space. This suggests that the final output matrix is in normalized space, not in the original voxel space.

Could you please guide me on how to obtain the 4x4 affine and rigid matrices in the original voxel space of my inputs?

Thank you very much for your help.
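Converting a normalized-space matrix back to voxel or world space is a change of basis: A_vox = inv(N_m) @ A_norm @ N_f, where N maps voxel indices to normalized coordinates, and A_world = affine_m @ A_vox @ inv(affine_f), where affine is each image's voxel-to-world matrix. The sketch below rests on several assumptions that would need checking against the keymorph code: the normalized space follows `torch.nn.functional.grid_sample` conventions (x along the last array axis, z along the first; `align_corners` as set in the code), the returned matrix maps fixed to moving normalized coordinates, and `affine_f`/`affine_m` describe the images exactly as fed to the model (after ToCanonical and any resizing). Both function names are hypothetical.

```python
import numpy as np


def vox_to_norm(shape, align_corners=True):
    """4x4 map from voxel indices (i, j, k) of an array with spatial shape (D, H, W)
    to normalized (x, y, z) in [-1, 1], with x along the last axis (W) and z
    along the first (D), as in torch.nn.functional.grid_sample."""
    D, H, W = shape
    M = np.zeros((4, 4))
    for row, (col, n) in enumerate([(2, W), (1, H), (0, D)]):
        if align_corners:   # index 0 -> -1, index n-1 -> +1
            M[row, col] = 2.0 / (n - 1)
            M[row, 3] = -1.0
        else:               # outer voxel edges at -1 and +1
            M[row, col] = 2.0 / n
            M[row, 3] = 1.0 / n - 1.0
    M[3, 3] = 1.0
    return M


def norm_to_voxel_and_world(A_norm, shape_f, shape_m, affine_f, affine_m,
                            align_corners=True):
    """Convert a 4x4 normalized-space matrix (assumed fixed -> moving) into
    voxel space (fixed voxel -> moving voxel) and world space (fixed -> moving)."""
    N_f = vox_to_norm(shape_f, align_corners)
    N_m = vox_to_norm(shape_m, align_corners)
    A_vox = np.linalg.inv(N_m) @ A_norm @ N_f
    A_world = affine_m @ A_vox @ np.linalg.inv(affine_f)
    return A_vox, A_world
```

Because ToCanonical reorients the data array and its affine together, world coordinates are unchanged, so A_world should also apply to the original files; the matrix in the original voxel space would then be inv(orig_affine_m) @ A_world @ orig_affine_f. This matrix maps fixed to moving (the direction used to resample the moving image onto the fixed grid); if the FreeSurfer tool expects the opposite direction, invert it.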
