HKervadec commented on August 16, 2024

The solution is to pre-compute the 3D distance maps offline and save them as .npy files with the axes ordered kxyz, k being the class axis. Pay attention to the spatial resolution at this step: the scipy function (`scipy.ndimage.distance_transform_edt`) has an extra, optional `sampling` parameter for that.
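A minimal sketch of that offline step, assuming the labels are stored as an integer volume of shape (X, Y, Z) and the voxel spacing is known. The helper names `labels_to_onehot`, `signed_distmap` and `precompute` and the output path are hypothetical; the sign convention mirrors the `one_hot2dist` function later in this thread (positive outside the object, zero or negative inside).

```python
from pathlib import Path
from typing import Optional, Sequence

import numpy as np
from scipy.ndimage import distance_transform_edt


def labels_to_onehot(labels: np.ndarray, K: int) -> np.ndarray:
    """Integer label volume (X, Y, Z) -> one-hot array (K, X, Y, Z)."""
    return (np.arange(K)[:, None, None, None] == labels[None]).astype(np.uint8)


def signed_distmap(onehot: np.ndarray,
                   spacing: Optional[Sequence[float]] = None) -> np.ndarray:
    """Per-class signed distance map, class axis first (kxyz)."""
    res = np.zeros(onehot.shape, dtype=np.float32)
    for k in range(onehot.shape[0]):
        pos = onehot[k].astype(bool)
        if pos.any():  # absent classes keep an all-zero map
            neg = ~pos
            # Outside: distance to the object. Inside: -(distance to background - 1),
            # which puts the object's border voxels at 0 when the spacing is 1.
            # `sampling` scales each axis by its physical voxel size.
            res[k] = distance_transform_edt(neg, sampling=spacing) * neg \
                - (distance_transform_edt(pos, sampling=spacing) - 1) * pos
    return res


def precompute(labels: np.ndarray, K: int, spacing: Sequence[float], out: Path) -> None:
    # Hypothetical offline step: run once per volume, before training
    np.save(out, signed_distmap(labels_to_onehot(labels, K), spacing))
```

Doing this once per volume keeps the (comparatively slow) Euclidean distance transform out of the training loop.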

Then, in the dataloader, load the 3D distance map as is (no extra transform besides converting it to a tensor), and sub-patch it at the same time as the original image.
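A sketch of the sub-patching step, assuming the image is stored as (C, X, Y, Z) and the distance map as (K, X, Y, Z) with identical spatial dimensions; `random_patch` is a hypothetical name. The point is that one random corner is drawn and reused for both arrays, so every distance value stays aligned with the voxel it describes.

```python
from typing import Optional, Tuple

import numpy as np


def random_patch(img: np.ndarray, dist: np.ndarray, patch: Tuple[int, int, int],
                 rng: Optional[np.random.Generator] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Crop the same (px, py, pz) window from img (C, X, Y, Z) and dist (K, X, Y, Z)."""
    rng = rng if rng is not None else np.random.default_rng()
    assert img.shape[1:] == dist.shape[1:], "image and distance map must be aligned"
    # Draw one corner per spatial axis, valid for both arrays
    corner = [int(rng.integers(0, s - p + 1)) for s, p in zip(img.shape[1:], patch)]
    window = tuple(slice(c, c + p) for c, p in zip(corner, patch))
    # The leading (channel / class) axis is kept whole
    return img[(slice(None), *window)], dist[(slice(None), *window)]
```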

You can then do the usual multiplication between the distance map and the softmax outputs.
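That multiplication amounts to averaging the element-wise product of the softmax probabilities and the signed distance maps. A NumPy sketch, assuming both arrays have shape (B, K, X, Y, Z); the function name and the optional `idc` list of supervised classes are illustrative, and the same expression carries over to PyTorch tensors.

```python
from typing import Optional, Sequence

import numpy as np


def boundary_loss(probs: np.ndarray, dist: np.ndarray,
                  idc: Optional[Sequence[int]] = None) -> float:
    """Mean of softmax probabilities weighted by the signed distance maps.

    probs, dist: arrays of shape (B, K, X, Y, Z); idc optionally restricts
    the loss to a subset of classes (e.g. only the foreground).
    """
    assert probs.shape == dist.shape
    if idc is not None:
        probs, dist = probs[:, list(idc)], dist[:, list(idc)]
    # Probability mass outside the object is penalised in proportion to its
    # distance from the boundary; mass inside (negative distance) lowers the loss
    return float((probs * dist).mean())
```

A prediction that matches the ground truth only places probability where the distance is zero or negative, so it scores lower than one that spills outside the object.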

This is what I implemented for the extension, but I haven't had time to put it in the repo yet. I will do so soon and then point to the exact parts of the code doing that; this should already give you a rough idea of how to proceed. Let me know if you need other details in the meantime.

Hoel

from boundary-loss.

xychenunc commented on August 16, 2024

Yes, this solves part of the problem. However, for training data obtained via data augmentation, there seems to be no better choice than computing the SDM (signed distance map) on the fly, and in that case training the network efficiently can become a big issue.

BTW, is the segmentation performance sensitive to the value of the loss weight for the boundary loss?

Thanks


HKervadec commented on August 16, 2024

With respect to the data augmentation

there is no better choice than computing the SDM on the fly

When you say "on the fly", do you mean computing the distance map inside the loss function?

The way I see it, the pre-computed distance map can be augmented as well, just like we perform the augmentation on the original ground truth. The overall code would look like this:

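from pathlib import Path
from typing import Callable, Dict, List, Tuple

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset


class DistDataset(Dataset):
    def __init__(self, files: List[Tuple[Path, Path, Path]], augment: Callable):
        # One (image, one-hot ground truth, pre-computed distance map) triplet of .npy files per sample
        self.files = files
        # Must apply the *same* random transform to the three arrays
        self.augment = augment

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        img_path, gt_path, dist_path = self.files[index]
        img, gt, dist = (np.load(p) for p in (img_path, gt_path, dist_path))

        aug_img, aug_gt, aug_dist = self.augment(img, gt, dist)
        del img, gt, dist  # Avoid returning those by accident

        return {"img": torch.as_tensor(aug_img, dtype=torch.float32),  # CWH shape
                "gt": torch.as_tensor(aug_gt, dtype=torch.float32),  # KWH shape
                "distmap": torch.as_tensor(aug_dist, dtype=torch.float32)}  # KWH shape


# Then in the training loop; net, optimizer, device, train_loader and the
# DiceLoss / BoundaryLoss callables are assumed to be defined elsewhere
α = 0.01
for data in train_loader:
    imgs = data["img"].to(device)  # BCWH shape
    gts = data["gt"].to(device)  # BKWH shape
    dists = data["distmap"].to(device)  # BKWH shape

    optimizer.zero_grad()

    pred_probs = torch.softmax(net(imgs), dim=1)  # softmax over the class axis

    dsc_loss = DiceLoss(gts, pred_probs)
    bl_loss = BoundaryLoss(dists, pred_probs)

    total_loss = dsc_loss + α * bl_loss
    total_loss.backward()
    optimizer.step()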
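One possible `augment(img, gt, dist)` for the dataset above, assuming 2D arrays laid out as (C, W, H) for the image and (K, W, H) for the ground truth and distance map, matching the shape comments in the code. The function is a hypothetical sketch: it applies the same random flips and 90-degree rotation to all three arrays. These transforms only permute pixels, so with isotropic in-plane spacing the transformed distance map equals the distance map of the transformed ground truth; transforms such as scaling or elastic deformation change distances and are not covered here.

```python
from typing import Optional, Tuple

import numpy as np


def augment(img: np.ndarray, gt: np.ndarray, dist: np.ndarray,
            rng: Optional[np.random.Generator] = None
            ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Apply one random flip/rotation, identically, to img (C,W,H), gt and dist (K,W,H)."""
    rng = rng if rng is not None else np.random.default_rng()
    flip_w, flip_h = rng.random(2) < 0.5  # draw the random choices once...
    n_rot = int(rng.integers(0, 4))       # ...number of 90-degree rotations

    def transform(a: np.ndarray) -> np.ndarray:
        # ...and reuse them for every array; axis 0 (channels/classes) is untouched
        if flip_w:
            a = np.flip(a, axis=1)
        if flip_h:
            a = np.flip(a, axis=2)
        # Contiguous copy, since torch.as_tensor rejects negative strides from np.flip
        return np.ascontiguousarray(np.rot90(a, k=n_rot, axes=(1, 2)))

    return transform(img), transform(gt), transform(dist)
```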

BTW, is the segmentation performance sensitive to the value of the loss weight for the boundary loss?

In our experiments, it was somewhat sensitive, but it still consistently gave some improvement even with a sub-optimal value; see Table 3 in our extension:

[Screenshot: Table 3 of "Boundary loss for highly unbalanced segmentation" (arXiv 1812.07032)]

Gradually increasing the boundary loss weight, or rebalancing the two loss weights, was not only better in performance but also much simpler to tune; to me, that is the main advantage of these schedules.


xychenunc commented on August 16, 2024


xychenunc commented on August 16, 2024


HKervadec commented on August 16, 2024

Thanks for your response. I tried your loss function on our dataset, but I have not seen improved performance so far. Does the simplification from the differential form to the integral form hold for 3D cases, as in the 2D example you show in your paper? I ask because I suspect the simplification may not hold for 3D boundaries and surfaces. If it still holds, could you please clarify and point me to some reference articles showing that? Thanks!

Yes, the result still holds, though in 3D you need to take into account the spatial resolution of each axis, since it may differ between axes. The updated distance computation function now looks like this:

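from typing import Optional, Tuple

import numpy as np
from scipy.ndimage import distance_transform_edt as eucl_distance


def one_hot2dist(seg: np.ndarray, resolution: Optional[Tuple[float, float, float]] = None,
                 dtype=None) -> np.ndarray:
    # seg: one-hot array of shape (K, X, Y, Z), every voxel in exactly one class
    assert set(np.unique(seg)) <= {0, 1} and (seg.sum(axis=0) == 1).all()
    K: int = len(seg)

    # Pass a float dtype (e.g. np.float32): the result contains negative values
    res = np.zeros_like(seg, dtype=dtype)
    for k in range(K):
        posmask = seg[k].astype(bool)  # plain `bool`: `np.bool` is missing in some NumPy versions

        if posmask.any():
            negmask = ~posmask
            # Positive outside the object, zero or negative inside;
            # `sampling` scales each axis by its voxel spacing
            res[k] = eucl_distance(negmask, sampling=resolution) * negmask \
                - (eucl_distance(posmask, sampling=resolution) - 1) * posmask
        # The idea is to leave blank the negative classes
        # since this is one-hot encoded, another class will supervise that pixel

    return res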

Passing resolution=None is equivalent to sampling=(1, 1, 1), i.e. unit spacing along every axis.

Another thing to take into account: if the spacing between slices becomes too large (e.g. 1 cm along the z axis versus 1 mm along the x and y axes), the 3D distance may no longer make much sense. It will depend on your application.
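To see what the spacing does, here is a small check with `scipy.ndimage.distance_transform_edt`, assuming (as the `sampling=resolution` call suggests) that this is the `eucl_distance` used in `one_hot2dist`. The volume and the 10:1 spacing ratio are illustrative assumptions: with unit sampling, the voxel one slice away from the object is at distance 1, while with a slice spacing ten times the in-plane spacing it is at distance 10, farther than a voxel three steps away in-plane.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

# Hypothetical volume indexed (z, x, y): a single object voxel in the middle
vol = np.ones((3, 7, 7), dtype=bool)  # True = background
vol[1, 3, 3] = False                  # the object voxel (distance 0)

iso = distance_transform_edt(vol)                               # unit spacing on every axis
aniso = distance_transform_edt(vol, sampling=(10.0, 1.0, 1.0))  # 10 units between slices, 1 in-plane

for name, d in (("isotropic", iso), ("anisotropic", aniso)):
    # neighbour in the next slice vs. voxel three steps away in-plane
    print(name, d[0, 3, 3], d[1, 3, 6])
```

With such a ratio, the nearest boundary for most voxels lies in-plane, so the 3D map behaves almost like a stack of 2D maps; whether that is acceptable depends on the application.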

Also, what does ‘rebalance’ mean, and how do I implement it? Thanks

Rebalancing means starting with a high weight on the DSC loss and a small one on the boundary loss, then slowly shifting the weight from the former to the latter:

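α = 0.01  # initial weight of the boundary loss


for e in range(epochs):
    for data in train_loader:
        ...  # zero_grad, forward pass, dsc_loss and bl_loss as in the loop above

        total_loss = (1 - α) * dsc_loss + α * bl_loss
        total_loss.backward()
        optimizer.step()

    # Shift weight towards the boundary loss after each epoch, capped at 0.99
    α = min(α + 0.01, 0.99)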
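Assuming α grows by 0.01 per epoch and is capped at 0.99, its value after e epochs has the closed form min(0.01 + 0.01·e, 0.99), reaching the cap around epoch 98. The sketch below puts that closed form next to the two weightings discussed in this thread: "increase" keeps the Dice weight at 1 (the `dsc_loss + α * bl_loss` form of the first training loop), while "rebalance" uses (1 − α) and α. The function names and defaults are hypothetical.

```python
def alpha_at_epoch(e: int, start: float = 0.01, step: float = 0.01,
                   cap: float = 0.99) -> float:
    """Weight of the boundary loss at epoch e (closed form of α = min(α + step, cap))."""
    return min(start + step * e, cap)


def combined_loss(dsc_loss: float, bl_loss: float, alpha: float,
                  mode: str = "rebalance") -> float:
    if mode == "increase":   # Dice weight fixed at 1, boundary weight grows
        return dsc_loss + alpha * bl_loss
    if mode == "rebalance":  # weight shifts from the Dice loss to the boundary loss
        return (1 - alpha) * dsc_loss + alpha * bl_loss
    raise ValueError(f"unknown mode: {mode}")


for e in (0, 10, 50, 98, 200):
    a = alpha_at_epoch(e)
    print(e, round(a, 2), round(combined_loss(0.5, -0.2, a), 4))
```

Both functions also work unchanged on PyTorch scalar tensors, since they only use arithmetic.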


xychenunc commented on August 16, 2024

I tried to understand the mathematics in your paper. It is interesting to see the beautiful connection between Eq. 2 and Eq. 3. However, I found it difficult to follow the derivation that connects the two. Specifically, in the paper, you mention that the two can be connected using the following:
[Screenshot (2020-12-02): the relation from the paper connecting Eq. 2 and Eq. 3]
To me, it is not obvious why the first two expressions are equivalent, because dD_G/dq is not a constant: the second term after the minus sign in
[Screenshot (2020-12-02): equation from the paper]
also depends on q, and I think it cannot be easily formulated.
Could you please explain this in more detail?

