
Comments (7)

HKervadec commented on September 18, 2024

Ah, yes, my bad.
Most of the utils work with shape bcwh, except one_hot2dist. And class2one_hot automatically adds an axis if needed:

def class2one_hot(seg: Tensor, C: int) -> Tensor:
    if len(seg.shape) == 2:  # Only w, h, used by the dataloader
        seg = seg.unsqueeze(dim=0)
    assert sset(seg, list(range(C)))
    assert len(seg.shape) == 3, seg.shape

    b, w, h = seg.shape  # type: Tuple[int, int, int]
    ...
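
For instance, a 2D hw mask coming out of the dataloader gains a batch axis, so the result is bcwh with b=1 (a quick sketch, assuming the class2one_hot above and a hypothetical 2-class 256x256 mask):

import torch

seg = torch.randint(0, 2, (256, 256))  # w, h -- as produced by the dataloader
res = class2one_hot(seg, 2)            # the missing batch axis is added automatically
print(res.shape)                       # torch.Size([1, 2, 256, 256])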

This is why I added an itemgetter(0) in the transform, to drop that dummy batch axis again. The code evolved in a weird way, and I also had to juggle utils expecting either Tensor or np.ndarray.

So here is what happens, and how to fix it (these type hints won't actually run, but they are useful for the explanation):

mask: np.ndarray[hw]
mask_tensor: Tensor[hw] = torch.tensor(mask, dtype=torch.int64)
mask_onehot: Tensor[chw] = class2one_hot(mask_tensor, 2)[0]  # [0] because the result is bchw
mask_distmap: np.ndarray[chw] = one_hot2dist(mask_onehot.cpu().numpy())
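
A runnable counterpart of that pseudocode (a sketch; it assumes class2one_hot and one_hot2dist are imported from the repo's utils, and uses a hypothetical 256x256 binary mask):

import numpy as np
import torch

mask = np.zeros((256, 256), dtype=np.uint8)  # hw, values in {0, 1}
mask[100:150, 100:150] = 1                   # hypothetical square object

mask_tensor = torch.tensor(mask, dtype=torch.int64)     # hw
mask_onehot = class2one_hot(mask_tensor, 2)[0]          # chw (the result is bchw)
mask_distmap = one_hot2dist(mask_onehot.cpu().numpy())  # chw, np.ndarray
assert mask_onehot.shape == mask_distmap.shape == (2, 256, 256)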

I will probably change this behavior in the future; it is too error-prone right now. But at least the assertions make it easy to catch.

Thanks for your great work. I really enjoy it.

Glad you find it useful!

GWwangshuo commented on September 18, 2024

Hi @HKervadec, thanks for your quick reply. I can run the code now; however, I encountered another problem. I have done something like this:

# transform the ground truth (mask) to dist maps when loading images
mask_tensor = torch.tensor(mask, dtype=torch.int64)     # hw
mask_onehot = class2one_hot(mask_tensor, 2)[0]          # chw
mask_distmap = one_hot2dist(mask_onehot.cpu().numpy())  # chw, np.ndarray
mask_distmap = torch.from_numpy(mask_distmap).float()   # back to a Tensor

and

# For training purposes
region_loss = GeneralizedDice(idc=[0, 1])
surface_loss = SurfaceLoss(idc=[1])

for input_image, dist_maps in dataloader:
    # input_image: bwh
    # dist_maps: bcwh
    optimizer.zero_grad()

    output_logits = net(input_image)  # bcwh
    outputs_softmaxes = F.softmax(output_logits, dim=1)  # bcwh

    loss = region_loss(outputs_softmaxes, dist_maps, None) + surface_loss(outputs_softmaxes, dist_maps, None)

    loss.backward()
    optimizer.step()
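
For reference, the surface loss boils down to an element-wise product of the softmax probabilities with the precomputed signed distance maps, averaged over the selected classes. A minimal sketch of that idea (a reconstruction, not necessarily the repo's exact code):

import torch
from torch import Tensor, einsum

class SurfaceLossSketch:
    # Sketch: mean of (softmax probs * signed distance maps) over the
    # selected class indices. Reconstructed, not the repo's exact code.
    def __init__(self, idc):
        self.idc = idc  # classes to penalize, e.g. [1] for the foreground

    def __call__(self, probs: Tensor, dist_maps: Tensor, _) -> Tensor:
        pc = probs[:, self.idc, ...].type(torch.float32)      # bcwh
        dc = dist_maps[:, self.idc, ...].type(torch.float32)  # bcwh
        return einsum("bcwh,bcwh->bcwh", pc, dc).mean()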

During training, I encountered the following problem:
[screenshot: training log showing a negative GDL value]

The GDL (region loss) is negative, which is incorrect. I went back to check its inputs, outputs_softmaxes and dist_maps, as follows:

outputs_softmaxes (I checked the corresponding shapes and values along each dimension):
[screenshot of outputs_softmaxes shapes and values]

and

dist_maps:
[screenshot of dist_maps shapes and values]

Which one of outputs_softmaxes or dist_maps went wrong? Can you give me some hints? I would really appreciate it. Thanks.

Best!

HKervadec commented on September 18, 2024

The region loss does not take the distance map as an input, but mask_onehot:

for input_image, dist_maps in dataloader:
    ...
    loss = region_loss(outputs_softmaxes, dist_maps, None) + surface_loss(outputs_softmaxes, dist_maps, None)  
    ...

should be

for input_image, onehot_labels, dist_maps in dataloader:
    ...
    loss = region_loss(outputs_softmaxes, onehot_labels, None) + surface_loss(outputs_softmaxes, dist_maps, None)  
    ...

But it is funny that the distance maps pass the simplex assertion in the GDL loss; I did not realize they could. I guess I should replace it with one_hot (which is literally simplex and sset([0, 1])).
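
A sketch of what those helpers could look like, following that description (the names match the repo's utils, but the bodies here are reconstructions):

import torch
from torch import Tensor

def uniq(a: Tensor) -> set:
    return set(torch.unique(a.cpu()).numpy())

def sset(a: Tensor, sub) -> bool:
    return uniq(a).issubset(sub)

def simplex(t: Tensor, axis: int = 1) -> bool:
    _sum = t.sum(axis).type(torch.float32)
    return torch.allclose(_sum, torch.ones_like(_sum))

def one_hot(t: Tensor, axis: int = 1) -> bool:
    # "literally simplex and sset([0, 1])"
    return simplex(t, axis) and sset(t, [0, 1])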

GWwangshuo commented on September 18, 2024

Well, thanks for your generous help. I can train and evaluate now. But there are still some things I am not sure are correct, e.g.:

After a few epochs of training, I get the results below:

contour_loss=0.01294, region_loss=0.55355, total_loss=0.56649 (training stage) --> jaccard: 0.00000
contour_loss=-0.00360, region_loss=0.42092, total_loss=0.41732 (training stage) --> jaccard: 0.60041
contour_loss=-0.01631, region_loss=0.23017, total_loss=0.21387 (training stage) --> jaccard: 0.62058
contour_loss=-0.01778, region_loss=0.23821, total_loss=0.22043 (training stage) --> jaccard: 0.63152
...

Why does the contour loss become negative? Is this negative contour loss correct or not? Thanks.

Best!

HKervadec commented on September 18, 2024

Is this negative contour loss correct or not?

Yes. The distance map is signed: negative inside the object, and positive outside of it. So a perfect segmentation will be multiplied only by negative distances: the optimal value is negative.
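
Concretely, per class the signed map can be built from two Euclidean distance transforms; a sketch with scipy (the repo's one_hot2dist follows this scheme, modulo details):

import numpy as np
from scipy.ndimage import distance_transform_edt as distance

def signed_distance(posmask: np.ndarray) -> np.ndarray:
    # Positive outside the object, negative inside it (zero on the
    # object's inner boundary); posmask is the boolean foreground.
    negmask = ~posmask
    return distance(negmask) * negmask - (distance(posmask) - 1) * posmask

mask = np.zeros((8, 8), dtype=bool)
mask[2:6, 2:6] = True
print(signed_distance(mask))  # negative inside the square, positive outside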

bluesky314 commented on September 18, 2024

Should we stop training if it is negative? I trained with dice on a small dataset for 20 epochs, and then added your surface loss, which started at around -0.7.

Edit: I did continue, and by decreasing the learning rate I got a 0.2+ on my dice score :) . I saw the training procedure in your paper about decreasing alpha. I would welcome any insights into how the dice and surface losses interact, and any training advice for better optimization.

Thanks, good work

HKervadec commented on September 18, 2024

Actually, I forgot to reply before closing this.

Basically, you stop once the loss function does not improve anymore (reaching convergence), and/or the validation dice has stopped improving. This does not differ from how you usually decide to stop a training.

The best achievable value for the boundary loss would be (- (distance(posmask) - 1) * posmask).mean(): there is a perfect overlap between the predicted object and the ground truth, so we are summing only negative distances inside the object. Notice that this optimal value will be different for each image.
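
As a sketch, that per-image optimum can be computed directly from the boolean ground-truth mask (with distance being scipy's distance_transform_edt, as in the utils):

import numpy as np
from scipy.ndimage import distance_transform_edt as distance

def optimal_boundary_loss(posmask: np.ndarray) -> float:
    # Value reached when the prediction is exactly 1 inside the object
    # and 0 outside: only the (negative) inside distances remain.
    return float((-(distance(posmask) - 1) * posmask).mean())

gt = np.zeros((64, 64), dtype=bool)
gt[16:48, 16:48] = True
print(optimal_boundary_loss(gt))  # negative, and different for each image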
