Comments (7)
The solution is to pre-compute the 3D distance maps offline and save them into a .npy file with the axes ordered kxyz, with k being the class axis. Pay attention to spatial resolution at this step -- the scipy function has an extra, optional parameter for that.
Then, in the dataloader, you load the 3D distmap as is (no extra transform besides converting to a tensor), and subpatch it at the same time as the original image.
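As a rough sketch of the offline step and the joint sub-patching described above (this is not the code from the repo): the helper names `signed_distance_kxyz` and `crop_patch`, the toy volume and the voxel spacing are all made up for illustration. The scipy function is assumed to be `scipy.ndimage.distance_transform_edt`, whose optional `sampling` argument takes the voxel size along each axis.

```python
import os
import tempfile

import numpy as np
from scipy.ndimage import distance_transform_edt


def signed_distance_kxyz(one_hot_seg, spacing):
    """Signed distance map per class for a (K, X, Y, Z) one-hot volume.

    Negative or zero inside the object, positive outside, using the same
    formula as the one_hot2dist function quoted later in this thread.
    """
    dist = np.zeros(one_hot_seg.shape, dtype=np.float32)
    for k, mask in enumerate(one_hot_seg.astype(bool)):
        if mask.any():
            outside = distance_transform_edt(~mask, sampling=spacing)
            inside = distance_transform_edt(mask, sampling=spacing)
            dist[k] = outside * ~mask - (inside - 1) * mask
    return dist


def crop_patch(image, dist, corner, size):
    """Cut the same spatial window out of the image and its distance map."""
    window = tuple(slice(c, c + s) for c, s in zip(corner, size))
    return image[(slice(None),) + window], dist[(slice(None),) + window]


# Offline: toy label volume with one foreground class (hypothetical data).
labels = np.zeros((8, 8, 4), dtype=np.int64)
labels[2:6, 2:6, 1:3] = 1
one_hot_seg = np.stack([labels == k for k in range(2)]).astype(np.uint8)
dist = signed_distance_kxyz(one_hot_seg, spacing=(1.0, 1.0, 2.5))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "case_000_dist.npy")
    np.save(path, dist)       # done once, before training
    loaded = np.load(path)    # done in the dataloader, no extra transform

# In the dataloader: sub-patch image and distance map together.
image = np.random.rand(1, 8, 8, 4).astype(np.float32)
img_patch, dist_patch = crop_patch(image, loaded, corner=(1, 1, 0), size=(4, 4, 4))
```

The only point that matters is that the crop window is computed once and applied to both arrays, so the distance values still line up with the voxels of the patch.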
You can then do the usual multiplication between distance map and softmaxes.
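For readers who have not seen that multiplication written out, here is a minimal NumPy sketch. The function names and the choice to average only over the selected foreground classes are assumptions for illustration, not necessarily what the repo does.

```python
import numpy as np


def softmax(logits, axis=0):
    """Numerically stable softmax over the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def boundary_loss(probs, dist_maps, class_idx=(1,)):
    """Mean of (softmax probability * signed distance) over the chosen classes.

    probs and dist_maps share the same shape, class axis first.
    """
    idx = list(class_idx)
    return float((probs[idx] * dist_maps[idx]).mean())


# Toy example: 2 classes, 4 voxels. The distance map of class 1 is
# negative inside the object and positive outside.
dist_maps = np.array([[1.0, 0.0, -1.0, -2.0],
                      [-1.0, 0.0, 1.0, 2.0]])

good_probs = np.array([[0.0, 0.0, 1.0, 1.0],
                       [1.0, 1.0, 0.0, 0.0]])  # mass inside the object
bad_probs = np.array([[1.0, 1.0, 0.0, 0.0],
                      [0.0, 0.0, 1.0, 1.0]])   # mass far outside

good = boundary_loss(good_probs, dist_maps)
bad = boundary_loss(bad_probs, dist_maps)
```

Probability mass placed outside the object gets multiplied by positive distances and is penalised; mass placed inside gets multiplied by negative distances and lowers the loss.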
This is what I implemented for the extension, but I haven't had time to put it in the repo yet. I will do so soon, and then point to the exact parts of the code doing that, but this should already give you a rough idea of how to proceed. Let me know if you need other details in the meantime.
Hoel
from boundary-loss.
Yes, this is a solution that can solve part of the problem because for training data that are obtained via data augmentation, it seems there is no better choice other than computing the SDM on the fly. In this case, training the network efficiently can be a big issue.
BTW, I want to know if the segmentation performance is sensitive to the value of loss weight for boundary loss?
Thanks
With respect to the data augmentation:
"there is no better choice other than computing the SDM on the fly"
When you refer to "on the fly", do you mean computing the distance map inside the loss function?
The way I see it, the pre-computed distance map can be augmented as well, just like we perform the augmentation on the original ground truth. The overall code would look like this:
from pathlib import Path
from typing import Dict, List, Tuple

from torch import Tensor
from torch.utils.data import Dataset


class DistDataset(Dataset):
    def __init__(self, *args, **kwargs):
        ...
        # One (image, ground truth, distance map) path triplet per sample
        self.files: List[Tuple[Path, Path, Path]]

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        img_path, gt_path, dist_path = self.files[index]

        # ... perform the transforms here
        aug_img, aug_gt, aug_dist = augment(img, gt, dist)
        del img, gt, dist  # Avoid returning those by accident

        return {"img": aug_img,  # CWH shape
                "gt": aug_gt,  # KWH shape
                "distmap": aug_dist}  # KWH shape


# Then in the training loop
α = 0.01
for data in train_loader:
    imgs = data["img"].to(device)  # BCWH shape
    gts = data["gt"].to(device)  # BKWH shape
    dists = data["distmap"].to(device)  # BKWH shape

    optimizer.zero_grad()

    pred_probs = softmax(net(imgs))

    dsc_loss = DiceLoss(gts, pred_probs)
    bl_loss = BoundaryLoss(dists, pred_probs)
    total_loss = dsc_loss + α * bl_loss

    total_loss.backward()
    optimizer.step()
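The `augment` call in the snippet above is left undefined. One possible sketch, assuming the augmentation is restricted to flips and 90° rotations: those only move pixels around, so with isotropic in-plane pixel spacing a pre-computed distance map stays valid after the transform. Scaling or elastic deformations change the distances themselves and are not covered here. The `rng` argument and the toy arrays are illustrative assumptions.

```python
import numpy as np


def augment(img, gt, dist, rng):
    """Apply the same random flip / 90-degree rotation to all three arrays.

    Each array is shaped (channels, W, H); only the spatial axes are touched.
    """
    arrays = [img, gt, dist]
    if rng.random() < 0.5:
        arrays = [np.flip(a, axis=2) for a in arrays]
    k = int(rng.integers(0, 4))  # number of quarter turns
    arrays = [np.rot90(a, k=k, axes=(1, 2)) for a in arrays]
    return tuple(np.ascontiguousarray(a) for a in arrays)


rng = np.random.default_rng(0)

# Toy sample: one foreground rectangle, and a stand-in "distance map"
# that is -1 inside each class and +1 outside (not a real distance).
gt_fg = np.zeros((6, 6), dtype=np.uint8)
gt_fg[1:3, 2:5] = 1
gt = np.stack([1 - gt_fg, gt_fg])
dist = np.where(gt == 1, -1.0, 1.0)
img = rng.random((1, 6, 6))

aug_img, aug_gt, aug_dist = augment(img, gt, dist, rng)
```

Because the random parameters are drawn once and reused for the image, the ground truth and the distance map, the three outputs remain aligned pixel for pixel.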
"BTW, I want to know if the segmentation performance is sensitive to the value of loss weight for boundary loss?"
In our experiments, it was somewhat sensitive, but it still consistently gave some improvement even with a sub-optimal value; see Table 3 in our extension.
The increase and rebalance strategies were not only better in performance, but also much simpler to tune -- to me that is their main advantage.
Thanks for your response. I tried your loss function on our dataset; however, I have not seen improved performance so far. I want to know if the simplification from the differential form to the integral form holds for 3D cases as it does in the 2D example you show in your paper. The reason I ask is that I think the simplification may not hold for 3D boundaries and 3D surfaces. If it still holds, could you please clarify and send me some reference articles showing that? Thanks!
Yes, the result still holds, though in 3D you need to take into account the spatial resolution of each axis, as it might differ. The updated distance computation function now looks like this:
from typing import Optional, Tuple

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt as eucl_distance

# one_hot is the repo's own helper checking that a tensor is one-hot encoded


def one_hot2dist(seg: np.ndarray, resolution: Optional[Tuple[float, float, float]] = None,
                 dtype=None) -> np.ndarray:
    assert one_hot(torch.tensor(seg), axis=0)
    K: int = len(seg)

    res = np.zeros_like(seg, dtype=dtype)
    for k in range(K):
        posmask = seg[k].astype(bool)  # the np.bool alias is gone from recent NumPy

        if posmask.any():
            negmask = ~posmask
            res[k] = eucl_distance(negmask, sampling=resolution) * negmask \
                - (eucl_distance(posmask, sampling=resolution) - 1) * posmask
        # The idea is to leave blank the negative classes
        # since this is one-hot encoded, another class will supervise that pixel

    return res
resolution = None corresponds to sampling = (1, 1, 1).
Another thing to take into account: if the space between slices becomes too big (like 1 cm on the z axis while it is 1 mm on the x and y axes), then the 3D distance may not make much sense. It will depend on your application.
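To make the effect of the resolution concrete, here is a small check of how `sampling` changes the result of `scipy.ndimage.distance_transform_edt`. The 3x3x3 volume and the spacing of 5 along z are arbitrary values chosen for illustration.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

# Every voxel is foreground except the centre one, so each value of the
# transform is simply the distance to the centre voxel.
mask = np.ones((3, 3, 3), dtype=bool)
mask[1, 1, 1] = False

default = distance_transform_edt(mask)                   # sampling=None
unit = distance_transform_edt(mask, sampling=(1, 1, 1))  # same as default
aniso = distance_transform_edt(mask, sampling=(1, 1, 5)) # thick slices on z

print(default[1, 1, 0])  # 1.0: one voxel away along z, unit spacing
print(aniso[1, 1, 0])    # 5.0: the same neighbour is now 5 units away
print(aniso[0, 1, 1])    # 1.0: in-plane neighbours are unchanged
```

With a large slice spacing, the nearest boundary point is almost always found within the same slice, which is one way to see why the 3D distance loses much of its meaning in that regime.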
Also, what does ‘rebalance’ mean and how to implement it? Thanks
Rebalancing corresponds to starting with a high weight on the DSC loss and a smaller one on the boundary loss, then slowly shifting them:
α = 0.01
for e in range(epochs):
    for data in train_loader:
        ...
        total_loss = (1 - α) * dsc_loss + α * bl_loss

        total_loss.backward()
        optimizer.step()

    # Shift the weight towards the boundary loss once per epoch, capped at 0.99
    α = min(α + 0.01, 0.99)
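To see what such a schedule produces, here is a standalone sketch that only computes the weights, with no training involved. The function name, the per-epoch update and the cap of 0.99 (applied with `min` so the weight grows gradually) are assumptions for illustration.

```python
def rebalance_schedule(epochs, start=0.01, step=0.01, cap=0.99):
    """Return one (dice_weight, boundary_weight) pair per epoch."""
    alpha = start
    weights = []
    for _ in range(epochs):
        weights.append((1 - alpha, alpha))
        alpha = min(alpha + step, cap)  # grow slowly, never past the cap
    return weights


weights = rebalance_schedule(epochs=200)
first_dice, first_boundary = weights[0]
last_dice, last_boundary = weights[-1]
```

With these values the boundary weight starts at 0.01, reaches the cap after roughly a hundred epochs and then stays there, so the Dice term dominates early training and the boundary term dominates at the end.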
I tried to understand the mathematics in your paper. It is interesting to see the beautiful connection between Eq 2 and 3. However, I found it difficult to follow your derivation connecting the two. Specifically, in the paper, you mentioned that the two can be connected using the following:
To me, it is not obvious why the first two are equivalent because getting dD_G/dq is not a constant for the second term after the minus sign in
is also related to q and I think cannot be easily formulated.
Could you please explain more on this?