
probabilistic-unet-pytorch's Introduction

Probabilistic UNet in PyTorch

A Probabilistic U-Net for segmentation of ambiguous images, implemented in PyTorch. This is a PyTorch implementation of the paper https://arxiv.org/abs/1806.05034; the original code can be found here: https://github.com/SimonKohl/probabilistic_unet.

Adding KL divergence for Independent distribution

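To implement a Gaussian distribution with an axis-aligned covariance matrix in PyTorch, I needed to wrap a Normal distribution in an Independent distribution. PyTorch did not define the KL divergence between two Independent distributions, so you need to add the following to the PyTorch source code at torch/distributions/kl.py (source: pytorch/pytorch#13545). Recent PyTorch releases already register an equivalent function in kl.py, so check your installed version before patching.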

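# in kl.py, Independent must be imported (from .independent import Independent) if it is not already
@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    # only defined when p and q reinterpret the same number of batch dims as event dims
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    # KL between the base distributions, summed over the reinterpreted (event) dims
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)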
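The sketch below is a sanity check on the patch. It is not part of this repository, and the helper names are illustrative. It compares the closed-form KL divergence between two axis-aligned Gaussians with the elementwise Normal KL summed over the latent dimension. That sum is what _kl_independent_independent computes when reinterpreted_batch_ndims is 1, so the two should agree up to floating-point error.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence


def diag_gaussian_kl(mu_p, log_sigma_p, mu_q, log_sigma_q):
    """Closed-form KL(p || q) for axis-aligned Gaussians, summed over the last (latent) dim."""
    var_p = torch.exp(2 * log_sigma_p)
    var_q = torch.exp(2 * log_sigma_q)
    kl = (log_sigma_q - log_sigma_p) + (var_p + (mu_p - mu_q) ** 2) / (2 * var_q) - 0.5
    return kl.sum(dim=-1)


def independent_kl(p, q):
    """Elementwise KL of the base Normals, summed over the single reinterpreted (event) dim."""
    return kl_divergence(p.base_dist, q.base_dist).sum(dim=-1)


torch.manual_seed(0)
mu_p, log_sigma_p = torch.randn(4, 6), 0.1 * torch.randn(4, 6)
mu_q, log_sigma_q = torch.randn(4, 6), 0.1 * torch.randn(4, 6)
p = Independent(Normal(mu_p, torch.exp(log_sigma_p)), 1)
q = Independent(Normal(mu_q, torch.exp(log_sigma_q)), 1)
print(independent_kl(p, q).shape)  # torch.Size([4]): one KL value per batch element
```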

Training

To train your own Probabilistic UNet in PyTorch, first write your own data loader. Then you can use the following code snippet to train the network:

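import torch
from probabilistic_unet import ProbabilisticUnet  # model class from this repository (assumed module name)
from utils import l2_regularisation  # regularisation helper from this repository (assumed module name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = ...  # define this yourself: an iterable of (patch, mask) batches
# no_channels, no_classes, filter_list, latent_dim, no_fcomb_convs, beta and epochs are also yours to set
net = ProbabilisticUnet(no_channels, no_classes, filter_list, latent_dim, no_fcomb_convs, beta)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
for epoch in range(epochs):
    for step, (patch, mask) in enumerate(train_loader):
        patch = patch.to(device)
        mask = mask.to(device)
        mask = torch.unsqueeze(mask, 1)  # add a channel dim, assuming masks arrive as (B, H, W)
        net.forward(patch, mask, training=True)  # computes the latent distributions and features used by net.elbo
        elbo = net.elbo(mask)
        # L2 penalty on the posterior, prior and fcomb weights
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()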
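The snippet relies on the repository's l2_regularisation helper, whose definition is not shown here. Assume it sums the L2 norms of a module's parameter tensors. A minimal stand-in under that assumption could look like this (l2_regularisation_sketch is a hypothetical name):

```python
import torch
import torch.nn as nn


def l2_regularisation_sketch(module):
    """Sum of the L2 norms of every parameter tensor in `module`.

    Hypothetical stand-in for the repository's l2_regularisation helper.
    """
    return sum(param.norm(2) for param in module.parameters())


layer = nn.Linear(2, 1)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[3.0, 4.0]]))  # norm 5
    layer.bias.zero_()                              # norm 0
print(float(l2_regularisation_sketch(layer)))  # 5.0
```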

Train on LIDC Dataset

One of the datasets used in the original paper is the LIDC dataset. I've preprocessed this data and stored it in pickle files, which you can download here. After downloading, place the files in a folder called 'data'. You can then train your own Probabilistic UNet on the LIDC dataset using the simple training script provided in train_model.py.
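The training log in one of the issues below shows the data split across several files (output_part1.pickle to output_part5.pickle). Here is a minimal sketch for merging such parts from the 'data' folder. It assumes each file unpickles to a dict. The exact record layout is not documented here, so both the function name and that assumption are illustrative.

```python
import os
import pickle


def load_pickle_parts(data_dir="data", suffix=".pickle"):
    """Load every pickle file in `data_dir` and merge them into one dict.

    Assumes each file unpickles to a dict; later files overwrite duplicate keys.
    """
    merged = {}
    for name in sorted(os.listdir(data_dir)):
        if name.endswith(suffix):
            with open(os.path.join(data_dir, name), "rb") as f:
                merged.update(pickle.load(f))
    return merged
```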


probabilistic-unet-pytorch's Issues

KL loss is NaN

Thanks for the code. I have a question: when I train on my own image data, in the class AxisAlignedConvGaussian,
self.encoder = Encoder(self.input_channels, self.num_filters, self.no_convs_per_block, initializers, posterior=self.posterior)
self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1,1), stride=1)
the conv_layer always outputs large values (1000+). When its output is used in
dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)),1)
torch.exp overflows and the result is NaN. Why is there no need to add torch.sigmoid after conv_layer to limit the value?
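One common mitigation is to clamp log_sigma before exponentiating. This is sketched here as an assumption, not as what the repository does, and the helper name and the bounds of -10 and 10 are arbitrary choices. Applying torch.sigmoid before torch.exp would also prevent the overflow. However, it confines sigma to the interval (1, e), so the latent Gaussian could never become narrower than a unit Gaussian. Persistently huge conv outputs usually point to diverging training, for example a too-high learning rate or unscaled inputs. Clamping hides that problem rather than fixing it.

```python
import torch
from torch.distributions import Independent, Normal


def axis_aligned_gaussian(mu_log_sigma, latent_dim, log_sigma_min=-10.0, log_sigma_max=10.0):
    """Build an axis-aligned Gaussian from a (batch, 2 * latent_dim) tensor.

    The first latent_dim columns are means; the rest are log standard deviations,
    clamped so that torch.exp cannot overflow to inf.
    """
    mu = mu_log_sigma[:, :latent_dim]
    log_sigma = torch.clamp(mu_log_sigma[:, latent_dim:], log_sigma_min, log_sigma_max)
    return Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)


raw = torch.full((3, 4), 1000.0)  # exaggerated conv output like the one reported above
dist = axis_aligned_gaussian(raw, latent_dim=2)
print(torch.isfinite(dist.base_dist.scale).all().item())  # True
```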

Issues when training on my own dataset

Hi,
Thank you for the code; it is quite helpful. I successfully ran it on the LIDC dataset and wanted to train on my own medical image dataset. The images are 41x41 and there are about 8000 training samples. The data is well preprocessed and loaded through my own dataloader. However, when I train, I get an assertion error from net.forward(patch, mask, training=True); it is raised at assert up.shape[3] == bridge.shape[3] in unet_block.py.
After examining the error in detail, I found it comes from RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 12 but got size 11 for tensor number 1 in the list. This happens when the code tries to concatenate up and bridge in out = torch.cat([up, bridge], 1).

I am a beginner in this field. Could you give me any insight into what bridge is here and how to tackle this problem? Thanks in advance.
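In a U-Net, bridge is most likely the skip-connection feature map from the encoder level at the same depth. It is concatenated with the upsampled decoder map up, so their spatial sizes must match. The reported 12 vs 11 is consistent with the arithmetic below, assuming four resolution levels (three 2x downsamplings) and pooling that rounds odd sizes up. The sketch shows the diagnosis and one workaround: padding inputs to a multiple of 8. The masks must be padded the same way. The function names are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def encoder_sizes(size, num_downsamples):
    """Spatial size at each encoder level, assuming 2x2 pooling that rounds up (ceil mode)."""
    sizes = [size]
    for _ in range(num_downsamples):
        sizes.append(math.ceil(sizes[-1] / 2))
    return sizes


def pad_to_multiple(x, multiple):
    """Zero-pad the last two dims of an (N, C, H, W) tensor up to the next multiple of `multiple`."""
    h, w = x.shape[-2:]
    pad_h, pad_w = (-h) % multiple, (-w) % multiple
    # F.pad pads from the last dim backwards: (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h))


print(encoder_sizes(41, 3))  # [41, 21, 11, 6]: upsampling 6 gives 12, but the skip map is 11
patch = pad_to_multiple(torch.zeros(1, 1, 41, 41), 2 ** 3)
print(patch.shape)  # torch.Size([1, 1, 48, 48])
print(encoder_sizes(48, 3))  # [48, 24, 12, 6]: every level now halves cleanly
```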

Posterior net input

In the original TensorFlow code, before the image is concatenated with the ground-truth segmentation, there is the line seg -= 0.5. I do not see it in the PyTorch version. Is it necessary to do this before feeding the input into the posterior net?
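The shift the issue describes moves a binary mask from {0, 1} to {-0.5, 0.5} before it is concatenated with the image along the channel dimension. Whether it changes results for this implementation is not established here. Below is a minimal sketch of the posterior input with the shift made optional (posterior_input is a hypothetical helper).

```python
import torch


def posterior_input(image, seg, center_seg=True):
    """Concatenate image and segmentation along the channel dim for the posterior net.

    With center_seg=True the binary mask is shifted from {0, 1} to {-0.5, 0.5},
    mirroring the `seg -= 0.5` line described for the TensorFlow code.
    """
    if center_seg:
        seg = seg - 0.5  # out-of-place, so the caller's mask is left unchanged
    return torch.cat([image, seg], dim=1)


image = torch.rand(2, 1, 8, 8)
seg = torch.randint(0, 2, (2, 1, 8, 8)).float()
x = posterior_input(image, seg)
print(x.shape)  # torch.Size([2, 2, 8, 8])
```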

Why do you randomly sample one label from 4 labels in load_LIDC_data.py?

Hi,

The code is helpful, thanks. However, I noticed that in load_LIDC_data.py you randomly sample one ground-truth label out of 4. I don't think this is what the original paper intends. Could you explain why you did this? Any comment is helpful. Thank you!

Best
Stella
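One possible rationale is offered here as an interpretation, not as the author's answer. Each image has four annotations. If a fresh one is drawn each time the sample is loaded, the network sees every rater's segmentation of the same image over many epochs. That is one way to expose it to the distribution of plausible segmentations. A sketch of that sampling step (pick_random_annotation is a hypothetical name):

```python
import random

import numpy as np


def pick_random_annotation(masks, rng=random):
    """Return one of the expert masks, chosen uniformly at random.

    `masks` is a sequence of same-shaped arrays (four per lesion in LIDC).
    """
    return masks[rng.randrange(len(masks))]


masks = [np.full((4, 4), rater, dtype=np.uint8) for rater in range(4)]
rng = random.Random(0)
chosen = pick_random_annotation(masks, rng)
print(chosen.shape)  # (4, 4)
```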

KL Divergence for Independent

To parametrize an axis-aligned Gaussian you wrap a Normal in an Independent, and add a patch for the otherwise undefined KL divergence.

I was wondering: isn't it possible to achieve the same thing (an axis-aligned multivariate Gaussian) using a MultivariateNormal instance? For example:

import torch
from torch.distributions import MultivariateNormal

mu = torch.zeros(batch_size, latent_dim)
log_sigma = torch.ones(batch_size, latent_dim)
# the covariance diagonal holds variances, sigma^2 = exp(2 * log_sigma)
cov = torch.diag_embed(torch.exp(2 * log_sigma))

mvn = MultivariateNormal(mu, covariance_matrix=cov)

mvn.batch_shape, mvn.event_shape
(torch.Size([batch_size]), torch.Size([latent_dim]))

considering that KL is already defined for the pair (MultivariateNormal, MultivariateNormal)?
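Both parametrizations describe the same distribution when the covariance is diagonal, so their KL divergences should match up to floating-point error. The sketch below checks this. It builds the MultivariateNormal from scale_tril, a diagonal Cholesky factor holding the standard deviations, and uses the KL already registered for (MultivariateNormal, MultivariateNormal). The two differ in cost. MultivariateNormal stores latent_dim x latent_dim matrices and uses triangular solves, whereas Independent(Normal) stays elementwise. For a small latent_dim the extra cost is likely negligible.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal, kl_divergence

torch.manual_seed(0)
batch_size, latent_dim = 4, 6
mu_p, log_sigma_p = torch.randn(batch_size, latent_dim), 0.1 * torch.randn(batch_size, latent_dim)
mu_q, log_sigma_q = torch.randn(batch_size, latent_dim), 0.1 * torch.randn(batch_size, latent_dim)

# Diagonal MultivariateNormal: scale_tril is the Cholesky factor, here diag(sigma)
mvn_p = MultivariateNormal(mu_p, scale_tril=torch.diag_embed(torch.exp(log_sigma_p)))
mvn_q = MultivariateNormal(mu_q, scale_tril=torch.diag_embed(torch.exp(log_sigma_q)))
kl_mvn = kl_divergence(mvn_p, mvn_q)  # shape (batch_size,)

# Independent(Normal): elementwise KL of the base Normals, summed over the latent dim
ind_p = Independent(Normal(mu_p, torch.exp(log_sigma_p)), 1)
ind_q = Independent(Normal(mu_q, torch.exp(log_sigma_q)), 1)
kl_ind = kl_divergence(ind_p.base_dist, ind_q.base_dist).sum(dim=-1)

print(kl_mvn)
print(kl_ind)
```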

NaN in training

Hi,

Sorry to bother you again. I got your code working with the data you provided, but unfortunately, after only a few steps, both the regularization loss AND the ELBO become NaN (and don't recover). Have you experienced something similar?

I'm using PyTorch from the latest master, by the way.

Thanks again,
Justus

EDIT: This does not seem to be a problem with the data. I tried to overfit on 5 samples, and it worked well for 1700 batches (batch size 4), but then the NaNs occurred.
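A common defensive pattern while debugging this kind of divergence is sketched here; it is not taken from the repository. Clip the gradient norm, and skip any update whose loss or gradients are not finite. This keeps NaNs out of the weights but does not fix their root cause. The name guarded_step and the value max_grad_norm=1.0 are illustrative choices.

```python
import math

import torch
import torch.nn as nn


def guarded_step(model, optimizer, loss, max_grad_norm=1.0):
    """Backpropagate `loss` and update, skipping the update if loss or gradients are not finite.

    Returns True if the optimizer stepped.
    """
    optimizer.zero_grad()
    if not torch.isfinite(loss):
        return False
    loss.backward()
    # clip_grad_norm_ returns the total gradient norm measured before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if not math.isfinite(float(grad_norm)):
        optimizer.zero_grad()
        return False
    optimizer.step()
    return True


model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
x = torch.randn(8, 3)
print(guarded_step(model, optimizer, model(x).pow(2).mean()))      # True
print(guarded_step(model, optimizer, torch.tensor(float("nan"))))  # False
```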

Predict a test image?

Dear stefanknegt,

I am very happy to see your PyTorch implementation of the Probabilistic U-Net. I am using it in a large-scale study and will cite your work.

I have trained the model successfully, but I don't see any script for the validation step. For simplicity, how can we predict a single image? (Is it the output of reconstruct()?)

Please help me. Thanks so much for your support.

Kind regards,
Kha.
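In the paper, test-time segmentations are sampled from the prior net, because the posterior net needs the ground-truth mask. The sketch below shows single-image prediction under stated assumptions, which are not confirmed here. It assumes net.forward(image, None, training=False) prepares the prior and U-Net features without a mask. It also assumes net.sample(testing=True) returns logits for one prior sample. predict_samples and the stub class are hypothetical and only illustrate the calling pattern.

```python
import torch


@torch.no_grad()
def predict_samples(net, image, num_samples=4, threshold=0.5):
    """Draw segmentation samples for one image from a trained ProbabilisticUnet-like model.

    Assumes (hypothetically) that forward(..., training=False) needs no mask and that
    sample(testing=True) returns logits for one sample drawn from the prior.
    Returns (probabilities, binary masks), each stacked along a new leading sample dim.
    """
    net.eval()
    net.forward(image, None, training=False)
    probs = torch.stack([torch.sigmoid(net.sample(testing=True)) for _ in range(num_samples)])
    return probs, (probs > threshold).float()


class _StubNet:
    """Stand-in with the assumed interface, only for trying the function out."""

    def eval(self):
        return self

    def forward(self, patch, segm, training=True):
        self._shape = patch.shape

    def sample(self, testing=False):
        return torch.zeros(self._shape)  # logits of 0 -> probability 0.5


probs, masks = predict_samples(_StubNet(), torch.rand(1, 1, 8, 8), num_samples=3)
print(probs.shape)  # torch.Size([3, 1, 1, 8, 8])
```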

Segmentation Fault (Core dumped)

Hi,
After trying to use your code, I got the following messages:

Loading file output_part2.pickle
Loading file output_part3.pickle
Loading file output_part5.pickle
Loading file output_part4.pickle
Loading file output_part1.pickle
Number of training/test patches: (13602, 1511)
No of classes: 1
No of input channels: 1
No of filters first layer: 32
Padding: 1
Segmentation fault (core dumped)

I ran the training script with default parameters, using recent master of both your repo and PyTorch (with the custom KL divergence added). The segfault seems to happen during model initialization.

Do you have any ideas on how this could happen?

Preprocessing

Hi @stefanknegt!

Thank you very much for your great work.

Could you please provide more information about the preprocessing applied to the provided data? Did you follow the same steps as in the original paper?
LIDC has 1010 patients, but I see 875 unique IDs in your data file. Did you apply any additional filtering? Also, your data contains 15096 images, while the paper used 12874 (8882, 1996, 1996). Did you include more slices per nodule?

Best!

Does it require multiple ground truth Y for each image X to train?

It is a brilliant idea to make the prior conditional on X. However, to learn a proper posterior distribution while avoiding mode collapse, are multiple ground-truth Y's for each image X absolutely necessary? Can training work when there is only one ground truth Y for each X? I guess the posterior net can encode the variations in the ground-truth Y's and then pass this knowledge on to the prior net. Without various ground-truth Y's, can the posterior net "guess" where ambiguity is most likely to occur and create a non-degenerate distribution?
Any comment is helpful. Thank you!

sample

Hi, could you explain the sampling process?
The sample function in probabilistic_unet.py is not used.
Thank you~

data preprocessing

Dear Sir,
I found the model to be very sensitive to data preprocessing. When I train the model on raw LIDC data with my own preprocessing, the KL loss soon becomes NaN. Could you please describe how you did the data preprocessing?

data pre-processing

Could you provide the code for data preprocessing? I have already downloaded the dataset you provided.

Results

Hi,

First of all, the implementation looks really nice. Have you run any training experiments, and are your results comparable to those reported in the paper?

Thanks!
Justus

Sampling a Segmentation

Hi Stefan, thanks for this nice PyTorch implementation. I have a question about the sample function in probabilistic_unet.py. In line 222 (during training), wouldn't we want to sample from the posterior net? (That is my reading of the paper's training heuristic; correct me if I am wrong.)
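In the paper's training objective, the reconstruction term uses a latent drawn from the posterior, and the prior is tied to it through the KL term. Prior samples are used at test time. With torch.distributions, the relevant distinction during training is rsample() versus sample(). Only rsample() is reparameterized, so gradients flow back into the network that produced the mean and standard deviation. A small illustration, with plain tensors standing in for posterior outputs:

```python
import torch
from torch.distributions import Independent, Normal

mu = torch.zeros(2, 3, requires_grad=True)
log_sigma = torch.zeros(2, 3, requires_grad=True)
posterior = Independent(Normal(mu, torch.exp(log_sigma)), 1)

z_train = posterior.rsample()  # reparameterized: z = mu + sigma * eps, differentiable
z_test = posterior.sample()    # drawn without tracking gradients

z_train.sum().backward()
print(z_train.requires_grad, z_test.requires_grad)  # True False
print(mu.grad)  # all ones, since dz/dmu = 1 for each element
```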

Getting NAN tensor from encoder

Hello - I've been getting this issue consistently while running the code as is, with the LIDC-IDRI data (downloaded from the provided link). This error is caught at the line dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)),1) within the class AxisAlignedConvGaussian() due to the fact that mu = tensor([[nan, nan], [nan, nan], [nan, nan], [nan, nan], [nan, nan]], device='cuda:0', grad_fn=<SliceBackward0>).

When I trace back where the NaN comes from, it originates directly in the encoder (the output of encoding = self.encoder(input)), all the way back to the output of the forward method in the Encoder class.

This issue persists regardless of batch size: I've run it with batch sizes 5 and 10, and I still get the error within the first epoch, at a seemingly random point after a few iterations.

I've verified the input and it looks okay: the images are as expected (I viewed them), and some masks are all 0's while others contain nonzero values. Nothing out of the ordinary.

I have not yet been able to track down why this is happening. Others seem to have experienced a similar issue, but more on the loss side; the issue I'm experiencing is within the forward pass, so it is independent of the loss.

Any insight would be appreciated!
ValueError: Expected parameter loc (Tensor of shape (10, 2)) of distribution Normal(loc: torch.Size([10, 2]), scale: torch.Size([10, 2])) to satisfy the constraint Real(), but found invalid values: tensor([[nan, nan], [nan, nan], [nan, nan], [nan, nan], [nan, nan], [nan, nan], [nan, nan], [nan, nan], [nan, nan], [nan, nan]], device='cuda:0', grad_fn=<SliceBackward0>)
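One way to locate the first layer that produces non-finite values is to register forward hooks on every submodule. The sketch below does this with a toy model (add_nan_hooks is a hypothetical helper). If the offending layer's inputs are finite, check whether its weights already contain NaN after an earlier optimizer.step(). That would explain NaNs appearing only after a few iterations. torch.autograd.set_detect_anomaly(True) can help trace the backward pass as well.

```python
import torch
import torch.nn as nn


def add_nan_hooks(model):
    """Register forward hooks that record, in call order, each submodule whose output has NaN/inf.

    Returns (offenders, handles); call handle.remove() on each handle when done.
    """
    offenders = []

    def make_hook(name):
        def hook(module, inputs, output):
            if torch.is_tensor(output) and not torch.isfinite(output).all():
                offenders.append(name)
        return hook

    handles = [m.register_forward_hook(make_hook(name)) for name, m in model.named_modules() if name]
    return offenders, handles


class _Log(nn.Module):
    def forward(self, x):
        return torch.log(x)  # NaN for negative inputs


model = nn.Sequential(nn.Linear(2, 2), _Log(), nn.Linear(2, 1))
with torch.no_grad():
    model[0].weight.fill_(-1.0)  # makes the first layer's output negative
    model[0].bias.zero_()
offenders, handles = add_nan_hooks(model)
model(torch.ones(1, 2))
print(offenders[0])  # 1 (the _Log layer, the first module to produce NaN)
for h in handles:
    h.remove()
```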

On the Problem of Pickle Dataset

Hello, I downloaded the dataset in pickle format from the author's link, but while training the network I got an error about indexing a non-existent index. I have not worked with a dataset in this format before. How can I solve this?

Testing

Hello!
Thank you for your code. I don't understand how to test the model after training. Do you have any example code for testing, or recommendations?

reconstruction_loss is very large

Thanks for the code. When I train on my own data, my reconstruction_loss is very large (200000000+). I later found that this comes from the line
reconstruction_loss = criterion(input=self.reconstruction, target=segm)
Because your loss is summed over all pixels, its value can be very large. Why not take the average?
When I use the average, the reconstruction_loss looks normal (below 100).
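The summed and averaged losses differ only by a constant factor, the number of elements, as the sketch below shows with random tensors. What changes in practice is the balance against the beta-weighted KL term. The KL is summed over the latent dimensions only. Switching the reconstruction term from sum to mean therefore gives the KL far more relative weight, and beta may need retuning.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 1, 128, 128)
target = torch.randint(0, 2, (2, 1, 128, 128)).float()

loss_sum = nn.BCEWithLogitsLoss(reduction="sum")(logits, target)
loss_mean = nn.BCEWithLogitsLoss(reduction="mean")(logits, target)

# The summed loss is the mean times the number of pixels, so it grows with image and batch size
print(logits.numel())  # 32768
print(loss_sum.item(), loss_mean.item())
```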

Loading pickle file with Python 3

When I try pickle.loads(), my machine throws 'not found pydicom'. I guess this is a problem with Python 3 loading Python 2 pickles? Could you kindly suggest a workaround? Thanks!
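The error most likely means the pickle refers to objects from pydicom. In that case pydicom has to be installed for unpickling to succeed, regardless of the Python version. Pickles written by Python 2 that contain byte strings or NumPy arrays usually also need encoding="latin1" under Python 3. A minimal loader (load_legacy_pickle is a hypothetical name):

```python
import pickle


def load_legacy_pickle(path):
    """Load a pickle that may have been written by Python 2.

    encoding="latin1" lets Python 3 decode Python 2 byte strings, including the raw
    buffers inside NumPy arrays. Modules the pickle refers to must still be importable.
    """
    with open(path, "rb") as f:
        return pickle.load(f, encoding="latin1")
```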

learning rate?

Did you use a learning-rate scheduler in the experiments, or is a constant 1e-4 fine throughout?
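The README snippet uses a constant learning rate of 1e-4. If you want to try a schedule, here is a minimal sketch with PyTorch's StepLR. The values step_size=10 and gamma=0.5 are arbitrary, and whether any schedule helps here is not established.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Hypothetical schedule: halve the learning rate every 10 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(25):
    # ... one epoch of training (loss.backward(); optimizer.step() per batch) goes here ...
    optimizer.step()  # placeholder so each scheduler step follows an optimizer step
    scheduler.step()  # advance the schedule once per epoch

print(optimizer.param_groups[0]["lr"])  # 2.5e-05 after two halvings
```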

testing code

Thank you for your training code.
Would you please provide the testing code?
Thank you.

Unet features

The U-Net does not seem to be trained (which is in accordance with the paper). However, as far as I can see, you do not load any weights (a pretrained model) into your U-Net model, which means the U-Net weights stay as initialized.

Is this a correct observation? Was the purpose of the code just to showcase the implementation of the Probabilistic UNet? I assume it would be difficult to train without a trained segmentation network.
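Whether the U-Net weights are updated can be checked directly. In the README snippet the optimizer is built from net.parameters(). That covers every submodule registered on net, including the U-Net if it is stored as an attribute of the model. A small diagnostic (params_missing_from_optimizer is a hypothetical helper) lists trainable parameters that no optimizer param group would update:

```python
import torch
import torch.nn as nn


def params_missing_from_optimizer(model, optimizer):
    """Names of trainable parameters of `model` that no optimizer param group will update."""
    in_optimizer = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return [name for name, p in model.named_parameters()
            if p.requires_grad and id(p) not in in_optimizer]


model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 1))
optimizer = torch.optim.Adam(model[1].parameters(), lr=1e-4)  # deliberately skips the first layer
print(params_missing_from_optimizer(model, optimizer))  # ['0.weight', '0.bias']
```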

About the four expert labels

The labels in the downloaded dataset contain four expert annotations. How should I deal with them: take their union, or something else?
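If you do fuse the annotations, union and majority vote are two simple options, sketched below (fuse_annotations is a hypothetical helper). Any fusion yields a single deterministic target, which discards the inter-rater variability the Probabilistic U-Net is meant to model. The alternative raised in an earlier issue is to pick one annotation at random each time an image is loaded.

```python
import numpy as np


def fuse_annotations(masks, mode="majority"):
    """Combine a stack of binary expert masks with shape (num_raters, H, W) into one mask.

    "union": foreground if any rater marked the pixel.
    "majority": foreground if more than half of the raters marked it.
    """
    masks = np.asarray(masks, dtype=bool)
    if mode == "union":
        return masks.any(axis=0)
    if mode == "majority":
        return masks.sum(axis=0) > masks.shape[0] / 2
    raise ValueError(f"unknown mode: {mode}")


masks = np.array([[[1, 0]], [[1, 0]], [[1, 1]], [[0, 0]]])  # four raters, 1x2 image
print(fuse_annotations(masks, "union"))     # [[ True  True]]
print(fuse_annotations(masks, "majority"))  # [[ True False]]
```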
