
Comments (18)

HKervadec commented on August 16, 2024

Hey,
It should be fairly easy to use. I assume you have N images in your batch, and 4 classes.

Since most of my code uses one-hot encoding, you will need to convert the true segmentation mask to it with class2one_hot (found in utils.py).

Once you have the one-hot encoding of shape Nx4x224x224, you can feed it into the one_hot2dist function. Note that this one runs on the CPU (it uses the numpy distance transform), so I highly recommend performing it inside the data loader for better performance.
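Concretely, the data-loader side could look roughly like this (a sketch only, assuming the helpers from utils.py; exact signatures may differ slightly):

import torch
from utils import class2one_hot, one_hot2dist  # helpers from this repository

def build_target(mask: torch.Tensor, K: int = 4):
    # mask: (H, W) integer label map with values in [0, K)
    one_hot = class2one_hot(mask, K)              # (K, H, W), one-hot encoded
    dist = one_hot2dist(one_hot.cpu().numpy())    # (K, H, W) signed distance map, computed on CPU
    return one_hot, torch.from_numpy(dist).float()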

After that, you should be able to use the SurfaceLoss class defined in losses.py. Notice that the class has one parameter, idc:

class SurfaceLoss():
    def __init__(self, **kwargs):
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = kwargs["idc"]
    ...

I use it to select which classes I want to supervise with that loss (sometimes only 1 class among the 4, sometimes classes [1, 2, 3], sometimes all of them). Depending on what you want to do with your own 4 classes, you can either use it or simply remove that argument.
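For example (a hypothetical usage sketch; the exact __call__ signature may differ between versions of losses.py, and some versions also take the one-hot target as an extra argument):

surface_loss = SurfaceLoss(idc=[1, 2, 3])        # supervise only the foreground classes
probs = torch.softmax(logits, dim=1)             # (N, 4, 224, 224) predicted probabilities
loss = surface_loss(probs, dist_maps)            # dist_maps precomputed in the data loader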


HKervadec commented on August 16, 2024

In our paper (binary problem, very high imbalance) we added it to GDL. I think it would also be a good fit for your problem, as the boundary loss is used to refine the segmentation learned by the GDL.
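In code, the combination amounts to something like this (names are placeholders, not the repository's API; the weight alpha is a hyper-parameter, often increased over training rather than kept fixed):

regional = gdl(probs, one_hot_target)        # regional term (GDL)
boundary = surface_loss(probs, dist_maps)    # boundary term
total_loss = regional + alpha * boundary
total_loss.backward()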


jlevy44 commented on August 16, 2024

Can I just use Dice Loss? Or would you recommend GDL?

Right now I have DL implemented, but have been trying to decide whether to switch to GDL.

Mainly deciding between DL+Surface or GDL+Surface.


HKervadec commented on August 16, 2024

I would recommend GDL, but you can try both. It's just that vanilla dice loss was not working at all for our very unbalanced problems.

We also have a GDL implementation in our repository, which likewise takes a one-hot encoded label mask.
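For reference, a minimal GDL sketch over one-hot targets (the repository's implementation in losses.py may differ in its details):

import torch
from torch import Tensor

def generalized_dice_loss(probs: Tensor, target: Tensor, eps: float = 1e-10) -> Tensor:
    # probs, target: (N, K, H, W); target is the one-hot encoded ground truth.
    target = target.float()
    w = 1.0 / ((target.sum(dim=(2, 3)) + eps) ** 2)            # inverse squared class volumes, (N, K)
    inter = (w * (probs * target).sum(dim=(2, 3))).sum(dim=1)
    union = (w * (probs + target).sum(dim=(2, 3))).sum(dim=1)
    return (1.0 - 2.0 * inter / (union + eps)).mean()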


jlevy44 commented on August 16, 2024

Are the class2one_hot and one_hot2dist compatible with albumentations?

https://github.com/albu/albumentations


jlevy44 commented on August 16, 2024

Also, how do they affect performance?


jlevy44 commented on August 16, 2024

Albumentations provides convenient transformations for the segmentation mask, but I am afraid it may break due to the one-hot encoding.


jlevy44 commented on August 16, 2024

I'll also find out once I test


jlevy44 commented on August 16, 2024

Though I think all of the changes you've made to your workflow to accommodate the distance and ground-truth maps will be disruptive to my existing workflow, where I can use the other loss functions interchangeably.


HKervadec commented on August 16, 2024

You could adapt my implementation to suit your needs, for instance by putting the class2one_hot call into the loss function directly. I still recommend putting the computation of the distance map in the data loader (by returning an extra Tensor, for instance), as the distance transform is only available in numpy.
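For instance, a hypothetical wrapper (not part of the repository) that encodes the integer label mask on the fly, so the rest of the pipeline can keep passing plain label maps:

def loss_from_labels(loss_fn, probs, labels, *extra):
    # loss_fn is any loss expecting (probs, one_hot_target, ...); labels is a (N, H, W) label map.
    target = class2one_hot(labels, K=probs.shape[1]).float()   # encode on the fly
    return loss_fn(probs, target, *extra)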

I have never used albumentations (nor heard of it before), but since at the end of the day the images are converted to Tensors, I don't see why you could not chain them.

Something that might be of use to you: I recently implemented a faster class2one_hot, which helps especially with higher numbers of classes (the difference will be small on binary problems).

from typing import Tuple

import torch
from torch import Tensor

# sset, uniq and one_hot below are small assertion helpers from the repository's utils.py.


def class2one_hot(seg: Tensor, K: int) -> Tensor:
    if len(seg.shape) == 2:  # Only w, h, used by the dataloader
        return class2one_hot(seg.unsqueeze(dim=0), K)[0]

    assert sset(seg, list(range(K))), uniq(seg)
    assert len(seg.shape) == 3, seg.shape

    b, w, h = seg.shape  # type: Tuple[int, int, int]

    res: Tensor = torch.stack([seg == c for c in range(K)], dim=1).type(torch.int32)
    assert res.shape == (b, K, w, h)
    assert one_hot(res)

    return res


def fast_class2one_hot(seg: Tensor, K: int) -> Tensor:
    if len(seg.shape) == 2:  # Only w, h, used by the dataloader
        return fast_class2one_hot(seg.unsqueeze(dim=0), K)[0]

    assert sset(seg, list(range(K))), uniq(seg)
    assert len(seg.shape) == 3, seg.shape

    b, w, h = seg.shape  # type: Tuple[int, int, int]

    device = seg.device
    # scatter_ builds the one-hot volume in a single call; seg must be an integer (long) tensor.
    res = torch.zeros((b, K, w, h), dtype=torch.int32, device=device).scatter_(1, seg[:, None, ...], 1)

    assert res.shape == (b, K, w, h)
    assert one_hot(res)

    return res
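A quick way to sanity-check that the two versions agree (a hypothetical snippet, assuming the utils.py helpers are importable):

seg = torch.randint(0, 4, (2, 224, 224))   # (b, w, h) integer label maps
assert torch.equal(class2one_hot(seg, 4), fast_class2one_hot(seg, 4))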


jlevy44 commented on August 16, 2024

Wonderful. Thanks for all of this help.


jlevy44 commented on August 16, 2024

I’ll let you know if I need anything else. Thanks!


chenz97 commented on August 16, 2024

Hi @HKervadec , I'm also trying to get the surface loss to work on my own problem. Following the discussion above, I did:
(1) given a GT mask, first convert it into a one-hot array of (K, H, W), then call one_hot2dist to get an output of (K, H, W), which may be roughly in the range of [-100, 100] (depending on image size, typically with absolute value smaller than 1000)
(2) given an output probability map of (N, K, H, W), and the batched distance map (N, K, H, W) as calculated above, call SurfaceLoss with idc removed.
Are the above steps correct? Did I miss anything?
In my experiments I found that the loss easily becomes negative, and the background class (also encoded in one channel) dominates the loss, which can reach a magnitude of -100. So I tried removing the background channel and also the simplex and onehot checks, but still always got negative losses, although much closer to zero. I had trained with Dice loss and got a reasonably high IoU of 0.75 in 100 iterations before adding the surface loss with a decreased alpha. As the training goes on, I even saw a gradual decrease in performance. Are there any possible solutions? Thank you very much!


HKervadec commented on August 16, 2024

Hey there,

(1) given a GT mask, first convert it into a one-hot array of (K, H, W), then call one_hot2dist to get an output of (K, H, W), which may be roughly in the range of [-100, 100] (depending on image size, typically with absolute value smaller than 1000)

That sounds correct.

(2) given an output probability map of (N, K, H, W), and the batched distance map (N, K, H, W) as calculated above, call SurfaceLoss with idc removed.

How many classes do you have? I would keep idc, and set it to idc=[1] for binary problems. You can keep the simplex and onehot assertions: they are useful to make sure you feed the correct data to the loss, while idc is what restricts the supervision to the foreground.

The loss can indeed be negative, since you are summing negative distances: the signed distance map is negative inside the ground-truth region, so a good prediction accumulates mostly negative values. In fact the perfect value is negative (and the optimal value is unique to each image).

As the training goes on, I even saw a gradual decrease in performance. Are there any possible solutions?

It is really application-specific: you will need to balance the two, and/or use some scheduling of the weights between the two losses. But first, I recommend testing with idc=[1] (or idc=[1, 2, ..., K-1] if you have K classes).

Let me know if something isn't clear, or have other questions.

Hoel


chenz97 commented on August 16, 2024

Hi @HKervadec , thanks for your reply. I have K classes, where K is input-specific, so I removed idc; when using all classes including the background, I did keep the simplex and onehot assertions.

But first, I recommend to test with idc=[1] (or idc=[1,2,...,K-1] if you have K classes).

I have already tested this, and that's why I removed self.idc in my case. It yields better training behavior, but as mentioned above, I saw negative values from the beginning and a gradual decrease in performance.
I'd like to know: does it make sense to have negative loss values right from the point of adding the loss, or should it generally be positive for a relatively long period of time before going negative? Thank you!


HKervadec commented on August 16, 2024

Woops, time is flying fast. Sorry for the delayed reply.

I'd like to know: does it make sense to have negative loss values right from the point of adding the loss, or should it generally be positive for a relatively long period of time before going negative? Thank you!

I have never really checked that. But if the boundary loss is negative right away (and doesn't go much lower), it probably means that the segmentation is already in pretty good shape (and/or in a local minimum of the first loss). Perhaps then yes, it would make sense to add the boundary loss earlier, or bit by bit (as we did in some experiments, by raising its weight over time).

I am afraid I cannot help you much more with that one. It is really dataset and task specific.

You could try (sorry if I missed something you already did):

  • Different scheduling strategies for the boundary loss weight (a rough sketch follows below)
  • Apply the boundary loss only on a subset of the classes (the most difficult ones?)
  • Use a different boundary loss weight for each class (especially if the classes have different sizes).
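For instance, a hypothetical weight schedule and per-class weighting could look like this (names and values are placeholders, not the paper's exact recipe):

def boundary_weight(epoch: int, step: float = 0.01, max_w: float = 1.0) -> float:
    # Raise the boundary-loss weight a little every epoch, capped at max_w.
    return min(step * (epoch + 1), max_w)

# Per-class weighting can be folded into the precomputed distance maps, e.g. for 4 classes:
# dist_maps *= torch.tensor([0.0, 1.0, 0.5, 2.0]).view(1, -1, 1, 1)

total_loss = regional_loss + boundary_weight(epoch) * boundary_loss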

Let me know,

Hoel


chenz97 commented on August 16, 2024

Hi @HKervadec , thank you for your suggestions! I will try them.
