
medsegdiff's People

Contributors

baiduihu, heikeyuhuajia, jiayuanz3, jiwei0921, lin-tianyu, nobleaustine, utkarshtambe10, wujunde


medsegdiff's Issues

Why are there some unused parameters?

    RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by making sure all forward function outputs participate in calculating loss.
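A common workaround (a minimal sketch, not the repository's official fix; model and local_rank are placeholder names) is to enable unused-parameter detection when wrapping the model:

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_ddp(model: nn.Module, local_rank: int) -> DDP:
        # find_unused_parameters=True makes DDP tolerate parameters that did
        # not contribute to the loss this iteration, at some extra overhead.
        # The cleaner fix is to make every forward output feed into the loss.
        return DDP(model, device_ids=[local_rank], find_unused_parameters=True)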

Segmentation problem

I would like to ask which part of the path is extracted as the slice ID here. My data may be organized differently from yours, and an error is raised once the code runs to this point.

    elif args.data_name == 'BRATS':
        # slice_ID = path[0].split("_")[2] + "_" + path[0].split("_")[4]
        slice_ID = path[0].split("_")[-3] + "_" + path[0].split("slice")[-1].split('.nii')[0]
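For reference, here is what that expression yields on a hypothetical BraTS-style filename (the path below is an assumption, not the repository's required layout):

    # Hypothetical slice path; adjust the pattern to your own naming scheme.
    path = ["/data/BRATS/BraTS20_Training_001_flair_slice77.nii.gz"]

    # split("_")[-3] -> "001"; split("slice")[-1].split('.nii')[0] -> "77"
    slice_ID = path[0].split("_")[-3] + "_" + path[0].split("slice")[-1].split('.nii')[0]
    print(slice_ID)  # 001_77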

loss problem

[screenshot of the training-loss code]

(1) I understand that mse_diff here predicts the noise: target (the added noise) has shape [b, 1, h, w], but model_output has shape [b, 2, h, w]. In a previous issue you answered that the two channels represent the mean and variance; can you explain the significance of computing MSE between them?
(2) In loss_cal, the target is the segmentation GT; does the model's cal output represent the predicted segmentation result? Can the cal output be used directly to represent the model's segmentation accuracy at the inference stage?

[screenshot of the sampling code]

(3) Can you explain the meaning of sample, x_noisy, org, cal, and cal_out respectively?
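For context on question (1), here is a minimal sketch of the guided-diffusion convention this code builds on (assuming learn_sigma=True, where the network emits two channel groups and only the first is compared against the noise; the variance channels feed a separate variational term):

    import torch

    b, c, h, w = 4, 1, 64, 64
    model_output = torch.randn(b, 2 * c, h, w)  # [predicted noise | variance]
    target = torch.randn(b, c, h, w)            # the noise that was added

    # MSE is computed on the noise-prediction channels only; the variance
    # channels are trained through a separate VLB/KL term, not this MSE.
    model_eps, model_var = torch.split(model_output, c, dim=1)
    mse_diff = ((target - model_eps) ** 2).mean()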

An error is reported when the image size is set to 512

This error occurs:

    Original Traceback (most recent call last):
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
        output = module(*input, **kwargs)
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 775, in forward
        uemb, cal = self.highway_forward(c, [hs[3],hs[6],hs[9],hs[12]])
      File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 744, in highway_forward
        return self.hwm(x,hs)
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 2152, in forward
        h = self.ffparser[d](h)
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 479, in forward
        x = x * weight
    RuntimeError: The size of tensor a (129) must match the size of tensor b (65) at non-singleton dimension 3
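The mismatch (129 vs 65) is characteristic of an rfft2-based module whose learnable frequency weight was sized for a fixed input resolution: rfft2 maps width W to W/2 + 1, so a weight built for a 128-wide feature map (65 frequencies) cannot multiply a 256-wide one (129 frequencies). A simplified sketch of the failure mode (sizes and the module body are assumptions, not the repository's exact FFParser):

    import torch
    import torch.nn as nn

    class TinyFFParser(nn.Module):
        """Simplified FFParser: a learnable per-frequency weight."""
        def __init__(self, dim, h, w):
            super().__init__()
            # rfft2 halves the last dimension: w_freq = w // 2 + 1
            self.weight = nn.Parameter(torch.randn(dim, h, w // 2 + 1, 2))

        def forward(self, x):
            x = torch.fft.rfft2(x, dim=(2, 3), norm="ortho")
            x = x * torch.view_as_complex(self.weight)  # fails if sizes differ
            return torch.fft.irfft2(x, dim=(2, 3), norm="ortho")

    ff = TinyFFParser(dim=32, h=128, w=128)  # weight sized for 128x128 maps
    ff(torch.randn(1, 32, 128, 128))         # OK: 65 == 65
    ff(torch.randn(1, 32, 256, 256))         # RuntimeError: 129 vs 65 at dim 3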

How to train this model on my own dataset

Thanks for the great work; I have a problem. My dataset has four parts: train_images, train_mask, test_images, test_mask. There are no JSON or CSV files. Should I create a JSON file in COCO format for my dataset, or can I just use the mask images? Thank you!
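If the masks are stored as images with the same filenames as the images, no COCO-style JSON should be needed; here is a minimal sketch of a paired dataset for this layout (directory names, sizes, and transforms are assumptions):

    import os
    from PIL import Image
    from torch.utils.data import Dataset
    import torchvision.transforms as T

    class PairedSegDataset(Dataset):
        """Image/mask pairs from parallel folders (train_images, train_mask, ...)."""
        def __init__(self, img_dir, mask_dir, size=256):
            self.img_dir, self.mask_dir = img_dir, mask_dir
            self.names = sorted(os.listdir(img_dir))
            self.img_tf = T.Compose([T.Resize((size, size)), T.ToTensor()])
            # nearest-neighbour keeps mask labels crisp instead of blurring them
            self.mask_tf = T.Compose([
                T.Resize((size, size), interpolation=T.InterpolationMode.NEAREST),
                T.ToTensor(),
            ])

        def __len__(self):
            return len(self.names)

        def __getitem__(self, i):
            name = self.names[i]
            img = Image.open(os.path.join(self.img_dir, name)).convert("RGB")
            mask = Image.open(os.path.join(self.mask_dir, name)).convert("L")
            return self.img_tf(img), self.mask_tf(mask)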

About DDTI dataset

Hi Wu, I am a beginner in medical image processing. Could you share the DDTI dataset and some example cases? Thanks a lot.

Problems with training

I encountered the following problem when training with the BRATS dataset! Can you help me? Thanks!

    File "D:\jace\pythonProject\MedSegDiff-master\MedSegDiff-master\guided_diffusion\train_util.py", line 83, in __init__
      self._load_and_sync_parameters()
    File "D:\jace\pythonProject\MedSegDiff-master\MedSegDiff-master\guided_diffusion\train_util.py", line 139, in _load_and_sync_parameters
      dist_util.sync_params(self.model.parameters())
    File "D:\jace\pythonProject\MedSegDiff-master\MedSegDiff-master\guided_diffusion\dist_util.py", line 76, in sync_params
      dist.broadcast(p, 0)
    RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
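A fix commonly suggested for this error (a sketch, not necessarily the repository's official one) is to perform the broadcast outside autograd tracking, since dist.broadcast writes to the parameter in place:

    import torch
    import torch.distributed as dist

    def sync_params(params):
        """Broadcast rank-0 parameters to all ranks without tripping autograd."""
        for p in params:
            with torch.no_grad():     # in-place writes to a leaf that requires
                dist.broadcast(p, 0)  # grad are only legal outside autograd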

Problems during sampling

I ran segmentation_sample.py and met this problem:

    Logging to /root/autodl-tmp/MedSegDif/med_results/img_out/
    creating model and diffusion...
    sampling...
    no dpm-solver
    /root/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py:1709: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
      warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
    Traceback (most recent call last):
      File "MedSegDif/med_scripts/segmentation_sample.py", line 163, in <module>
        main()
      File "MedSegDif/med_scripts/segmentation_sample.py", line 109, in main
        sample, x_noisy, org, cal, cal_out = sample_fn(
      File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 553, in p_sample_loop_known
        for sample in self.p_sample_loop_progressive(
      File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 624, in p_sample_loop_progressive
        out = self.p_sample(
      File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 435, in p_sample
        out = self.p_mean_variance(
      File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/respace.py", line 90, in p_mean_variance
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
      File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 319, in p_mean_variance
        model_mean, _, _ = self.q_posterior_mean_variance(
      File "/root/autodl-tmp/./MedSegDif/med_guided_diffusion/gaussian_diffusion.py", line 219, in q_posterior_mean_variance
        assert x_start.shape == x_t.shape
    AssertionError

I know this happens because x_start.shape is not equal to x_t.shape. However, my dataset is similar to ISICDataset, so this seems strange to me.
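A quick way to localize this kind of mismatch (a debugging sketch; assumes you can add a line just before the failing assertion in q_posterior_mean_variance):

    # Hypothetical debug line before the assertion; compare both shapes with
    # the --image_size flag and what your dataloader actually returns.
    print("x_start:", tuple(x_start.shape), "x_t:", tuple(x_t.shape))
    assert x_start.shape == x_t.shape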
Thanks a lot if you can reply.

training & inference time cost

Hi! I am trying to use your code on BRATS2020 with sliced input images.

If I follow the README and use this command on a single 3090 GPU (24 GB), what is the expected time cost?

[screenshot of the training command]

By the way, could you please share one of your training logs as an example? Many thanks!

Sampling output image visualization problem?

Hello,

I'm currently training this model on my own dataset. I created a separate dataloader Python file; the file and folder structure of the dataset is the same as ISIC. No code other than the data loading in segmentation_train.py and segmentation_sample.py was changed. The model has been trained for 30,000 steps so far, but when I ran segmentation_sample.py on the test images, I got these masks.

Are these mask outputs normal for this model?

[attached ensemble outputs: 89_16_output_ens, 89_24_output_ens]

python 3.8.16
torch 1.13.1
torchvision 0.14.1
torchsummary 1.5.1
opencv 4.7.0.68
scikit-image 0.19.3

the hyperparameter settings issue

Thank you for your great work!

Can the parameter settings --diffusion_steps 50 --dpm_solver True be used during the training process, or can they only be used for sampling?

DPM-Solver Memory Problem

Hi! Thanks for your excellent work. I successfully trained a MedSegDiff-B model on my dataset but have trouble sampling.

Specifically, while using DPM-Solver to sample, GPU memory usage grows with the num_ensemble parameter: on every ensemble pass (1/5), GPU memory grows by around 2 GB, and the run finally crashes with a "CUDA out of memory" error.

This lets me sample only one image before the inference process collapses. Is this a normal phenomenon? If not, how can I deal with it?

PS: the original inference process can sample images without increasing GPU memory.
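If the growth comes from state being retained across ensemble passes, here is a sketch of the usual mitigation (illustrative names; sample_fn stands for whatever sampler you call):

    import torch

    def ensemble_sample(sample_fn, model, shape, num_ensemble=5):
        """Run num_ensemble sampling passes without accumulating GPU memory."""
        samples = []
        for _ in range(num_ensemble):
            with torch.no_grad():          # sampling needs no gradient graph
                out = sample_fn(model, shape)
            samples.append(out.cpu())      # keep results in host memory
            del out
            torch.cuda.empty_cache()       # return cached blocks to the driver
        return samples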

Question about the mse loss for training segmentation tasks

Thanks for your great work and your effort in sharing this code. I am wondering: is it stable to use MSE loss for training segmentation tasks? Usually we use cross-entropy loss for this task, and that is what I am curious about.

Thanks for reading this issue and I am looking forward to your reply!

When I use DPM-Solver, CUDA runs out of memory

I looked at other issues; someone said to change the PyTorch version to 1.8.1, but when I tried it, it didn't work. I also tried other versions of PyTorch, and they still didn't work.

Pretrained model

Hi, could you please provide your pre-trained models? I trained a model, but the sampling result is not right: the max pixel value is about 10, so the pictures are all black.
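If the raw samples live outside [0, 1], here is a sketch of the usual post-processing to view them as a binary mask (the min-max normalization and the 0.5 threshold are assumptions, not the authors' prescribed pipeline):

    import torch

    def to_binary_mask(pred: torch.Tensor, thresh: float = 0.5) -> torch.Tensor:
        """Min-max normalize a raw prediction to [0, 1], then threshold."""
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
        return (pred > thresh).float()  # 0/1 mask; multiply by 255 to save as PNG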

The number of steps used in training

Hi,

This is an excellent project to share. I have a question from running the program: is the number of steps set to 1000 during training, while only 100 steps are used during inference?

Thanks if the question can be answered~

Best,
CaviarLover

Epochs of training

May I ask how many epochs you train to obtain the results in the paper?

For multi-class seg, num_class=3 for example

Hi Junde Wu,

I have some questions for you.

The hyper-parameter in_ch=2 is fixed regardless of binary or multi-class task, where the two channels are the image and the mask.
For a multi-class task, is the only change we are supposed to make the calibration output, i.e. sigmoid to softmax, so that we get a [1 3 H W] calibration and a [1 2 H W] model_output? Is that correct?

If we change in_ch = 3 + 1 (one-hot plus the image condition), we get the [1 3 H W] calibration; however, I do not know what the model_output should be. Is it something like [1 3 2 H W], or is it still [1 2 H W]? If the latter, using the mask rather than one-hot as the input of the diffusion model seems more meaningful?

I grouped a 5-class task into a binary case to check the results. Here is one visualization; is it correct? From top to bottom: image, recovery from the diffusion model, calibration, and the linear combination of the recovery and calibration.

[visualization screenshot]
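For reference, a minimal sketch of the sigmoid-to-softmax change asked about above (assuming the calibration head emits one channel per class; this is not confirmed as the authors' intended multi-class recipe):

    import torch
    import torch.nn.functional as F

    num_classes = 3
    cal_logits = torch.randn(1, num_classes, 64, 64)  # hypothetical calibration output

    # Binary case: per-pixel foreground probability via sigmoid.
    # Multi-class case: per-pixel class distribution via softmax over channels.
    cal_probs = F.softmax(cal_logits, dim=1)  # [1, 3, H, W], sums to 1 per pixel
    pred_mask = cal_probs.argmax(dim=1)       # [1, H, W] integer class map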

Thanks!
Ping

Problem of dimension

I am curious why the model's output channel dimension is 2: my output is [b, image_size, image_size], but your code needs the output to be [b, 2, image_size, image_size].
[code screenshot]

brats data slice

Hello, may I ask how you sliced the 3D BRATS data into 2D data in order to put it in the directory?
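One way to do the slicing (a nibabel-based sketch; the axial axis, file naming, and .npy output format are assumptions):

    import os
    import nibabel as nib
    import numpy as np

    def slice_volume(nii_path, out_dir):
        """Save each axial slice of a 3D NIfTI volume as a .npy file."""
        vol = nib.load(nii_path).get_fdata()  # e.g. (240, 240, 155) for BraTS
        base = os.path.basename(nii_path).replace(".nii.gz", "")
        os.makedirs(out_dir, exist_ok=True)
        for z in range(vol.shape[2]):         # iterate over the last (axial) axis
            np.save(os.path.join(out_dir, f"{base}_slice{z}.npy"), vol[:, :, z])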

error in create_argparser

    defaults.update({k: v for k, v in model_and_diffusion_defaults().items() if k not in defaults})

Hi, I believe this is what you want to have; otherwise the values will be overwritten by the predefined ones.
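The point being made: dict.update overwrites existing keys, so the merge order decides which defaults win. A tiny illustration (the values are made up):

    model_defaults = {"image_size": 64, "num_channels": 128}
    defaults = {"image_size": 256}  # script-specific setting that should win

    # Wrong: defaults.update(model_defaults) would clobber image_size with 64.
    # Right: only add keys the script has not already set.
    defaults.update({k: v for k, v in model_defaults.items() if k not in defaults})
    print(defaults)  # {'image_size': 256, 'num_channels': 128}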

Problem about calculating loss

Hello, I ran scripts/segmentation_train.py on my own dataset and met this problem:

    RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel; (2) making sure all forward function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).

Thank you !

questions about getting test scores

Are the scores in Table I of MedSegDiff the official test scores or 5-fold cross-validation scores? nii.gz files need to be uploaded to the BraTS2020 website to get the official test score, but I cannot find the nii.gz-generating section in the source code, so I don't know how to get the test score of a trained model.
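A sketch of re-stacking 2D predictions into a nii.gz for submission (assumes you kept the per-slice order and have the reference volume for the affine/header; this is not code from the repository):

    import nibabel as nib
    import numpy as np

    def stack_to_nifti(slices, ref_nii_path, out_path):
        """Stack ordered 2D slice predictions into a 3D NIfTI, reusing the
        reference volume's affine and header so the geometry matches."""
        ref = nib.load(ref_nii_path)
        vol = np.stack(slices, axis=2).astype(np.uint8)  # (H, W, num_slices)
        nib.save(nib.Nifti1Image(vol, ref.affine, ref.header), out_path)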

about the training

When I run scripts/segmentation_train.py, I have a problem.

    Traceback (most recent call last):
      File "D:\jace\pythonProject\MedSegDiffv2-master\scripts\segmentation_train.py", line 110, in <module>
        main()
      File "D:\jace\pythonProject\MedSegDiffv2-master\scripts\segmentation_train.py", line 62, in main
        TrainLoop(
      File "D:\jace\pythonProject\MedSegDiffv2-master\guided_diffusion\train_util.py", line 83, in __init__
        self._load_and_sync_parameters()
      File "D:\jace\pythonProject\MedSegDiffv2-master\guided_diffusion\train_util.py", line 139, in _load_and_sync_parameters
        dist_util.sync_params(self.model.parameters())
      File "D:\jace\pythonProject\MedSegDiffv2-master\guided_diffusion\dist_util.py", line 78, in sync_params
        dist.broadcast(p, 0)
      File "C:\SoftWare\python 3.10\lib\site-packages\torch\distributed\distributed_c10d.py", line 1408, in broadcast
        work.wait()
    RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

Problems of sampling

  1. log.txt and progress.csv do not output anything.
  2. When segmentation_sample.py is running, the terminal reports connection errors.
  3. I set num_ensemble=5 but get only one output image; from the terminal, it seems something stops the iteration.

What parameters or arguments should I revise, or what else can I do?

[attached terminal screenshots]

Loading custom MRI datasets

Hi!

Thanks for this repo, really exciting stuff!

I have a sagittal MRI dataset with dimensions (512, 512, 7) (H, W, slice) in NIfTI format. How should the input be shaped for the network to train? In my understanding, since the model is a 2D U-Net, the network will be trained on each slice of each patient individually; however, I'm a bit confused about what the input to the network should be.

When will the training stop?

Thank you for your excellent work. I wonder how many iterations are used for training, since I could not find a stopping condition in the code. Thank you.

BRATS Dataset training testing split

Hi there, nice work.
Can you provide your training and testing split for the BRATS21 dataset? I am trying to reproduce your work, so I would like to know how to create the samples I need to train and infer on. In the paper you wrote that train/validation/test sets are split following the default settings of the dataset, but the official validation and test splits don't have labels. Can you tell me how to find them?

Also, did you do any preprocessing other than slicing the images from 3D to 2D?

question about loss calculation

I have a question regarding loss calculation:
for training, loss = (losses["loss"] * weights + losses['loss_cal'] * 10).mean() is used.
Why do you weight the direct prediction of the ground truth higher than the comparison with a less noisy version?

Is there a reason that, for inference, a different composition of cal and sample is used depending on the Dice score?

Thanks in advance!

Code Training Problem

Following the author's instructions, after downloading the dataset and running the code, the following issues arise.
[error screenshot]
Then, after debugging, we found that the length of hs is 12, so we changed the index to 11 and ran again, which caused the following problems:
[two error screenshots]
Could the author please help answer this? Thank you.

Model differences when I run segmentation_sample.py

Excellent work!
I'm a beginner in the field of deep learning.
I have a question: when I run segmentation_sample.py, what is the difference between savedmodel_XXXX.pt, optsavedmodel_XXXX.pt, and emasavedmodel_XXXX.pt?
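For context on the naming (a sketch of the guided-diffusion convention these filenames appear to follow, where savedmodel holds the raw weights, optsavedmodel the optimizer state, and emasavedmodel an exponential moving average of the weights, which is typically the checkpoint to sample from):

    import torch

    @torch.no_grad()
    def update_ema(ema_params, model_params, rate=0.9999):
        """EMA of the weights; the 'ema' checkpoint stores these parameters."""
        for ema_p, p in zip(ema_params, model_params):
            ema_p.mul_(rate).add_(p, alpha=1 - rate)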
Thanks a lot.

Training with different image size

Training with image size = 128 on the ISIC dataset fails.

The failure occurs when values go through FFParser.
[error screenshot]

I assume this is related to hardcoded values in the instantiation of the FFParser modules (unet.py):
[code screenshot]

Do you have the same issue, and what would be a smart fix?
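One possible fix, sketched under the assumption that the FFParser weights are instantiated with hardcoded spatial sizes: derive them from the configured image size instead (reusing the TinyFFParser sketch from the image-size-512 issue above; the channel/stride pairs are assumptions):

    import torch.nn as nn

    def build_ffparsers(image_size, stages=((128, 1), (256, 2), (512, 4), (1024, 8))):
        # Size each stage's frequency weight from its actual feature-map size
        # rather than from a fixed 256-pixel assumption.
        return nn.ModuleList(
            TinyFFParser(dim=ch, h=image_size // s, w=image_size // s)
            for ch, s in stages
        )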

Sample Visualization and Metrics

We trained the diffusion model for more than 100,000 steps and sampled the test images.
However, the predictions seem wrong, as pixel values vary from 0 to tens instead of {0, 1}.
How do we obtain the final segmentation mask?
[sample output: 0000015_output]

Questions about the forward function of the UNetModel

Dear authors:

I have some questions about the highway_forward function (Generic_UNet), detailed as follows:

  1. On the ISIC dataset, the resolution of the input image is [batch, 3, 64, 64], which means that c is [batch, 3, 64, 64]. But hs[12] is out of range, so we changed the indices to 2, 5, 8, and 11, corresponding to resolutions of [..., 64, 64], [..., 32, 32], [..., 16, 16], and [..., 8, 8], respectively. Is this operation right?
  2. The h and hb on L768 have different resolutions. Should they be resized?

[code screenshot]

  3. When calculating x = x*ha*hb in the forward function of Generic_UNet, x, ha, and hb have different resolutions. Should they be resized?

[code screenshot]

I sincerely hope for your response. Thanks a lot!

When will the V2 code be released?

I had high hopes for diffusion-based semantic segmentation, but the results of the V1 code on my own dataset are not ideal. I hope the V2 version can be released as soon as possible.

When args.in_ch = 5, the following error occurs

    Original Traceback (most recent call last):
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
        output = module(*input, **kwargs)
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 773, in forward
        h = module(h, emb)
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/data1/ppw/works/MedSegDiff/guided_diffusion/unet.py", line 86, in forward
        x = layer(x)
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 446, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "/data1/ppw/anaconda3/envs/PyTorch/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
        return F.conv2d(input, weight, bias, self.stride,
    RuntimeError: Given groups=1, weight of size [128, 5, 3, 3], expected input[8, 4, 256, 256] to have 5 channels, but got 4 channels instead
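The error says the first conv expects in_ch=5 channels but the assembled input has only 4, so one channel is missing from the concatenation. A tiny illustration of the expected bookkeeping (the channel counts are assumptions matching the error message):

    import torch

    img = torch.randn(8, 4, 256, 256)      # e.g. 4 BraTS modalities
    x_noisy = torch.randn(8, 1, 256, 256)  # 1 noisy-mask channel from diffusion

    x = torch.cat([img, x_noisy], dim=1)   # [8, 5, 256, 256] matches in_ch=5;
    print(x.shape)                         # passing img alone gives only 4 channels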

Forward function in Generic_UNet

[code screenshot]
Hello, I notice the code in Generic_UNet defines a conv in the forward function, using
[code screenshot]
and it will use different weights on the next call. Can you tell me the reason?
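For anyone hitting this: a layer constructed inside forward is re-initialized on every call, so its weights are random each time and are never registered with the module (the optimizer never updates them). A minimal illustration (not the repository's code):

    import torch
    import torch.nn as nn

    class Bad(nn.Module):
        def forward(self, x):
            conv = nn.Conv2d(3, 8, 3, padding=1).to(x.device)  # fresh random weights
            return conv(x)

    m = Bad()
    x = torch.randn(1, 3, 16, 16)
    print(torch.allclose(m(x), m(x)))  # False: two calls, two different convs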

sample running error

Error when using the officially documented run command:

    mentation_sample.py --data_dir /home/yp/diskdata/workspace/medsegdiff/dataset/ISIC --model_path /home/yp/diskdata/workspace/medsegdiff/results/savedmodel020000.pt --image_size 256 --num_channels 128 --class_cond False --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16 --diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --num_ensemble 5
