Comments (4)
Hi Ray,
Thanks for those pointers. I had forgotten about torchio (though I think it's still fairly new?), so it is a good reminder. It looks much simpler to use than sub-patching and reconstructing the 3D volume manually.
From a quick overview I would guess that the transforms handle image and label in a different fashion.
I will check that in more detail after the MIDL deadline, but adding another implementation for a third input type, such as a distance map, might be a good option.
I guess that work would fall on me, but it would be relevant for the community to be able to manipulate distance maps -- beyond their use for the boundary loss. "Incompatible" transforms (if any) could simply raise an exception explaining why they cannot be applied.
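A minimal sketch of the "raise an exception explaining why" idea, in plain Python. Every name here (`DISTANCE_MAP_SAFE`, `IncompatibleTransformError`, `check_distance_map_transforms`) is hypothetical and exists in neither torchio nor boundary-loss. Which transforms count as safe is also an assumption made for illustration; for instance, `RandomAffine` is only listed on the assumption that it is configured without scaling.

```python
# Hypothetical helper: neither torchio nor boundary-loss provides it.
# The "safe" set is an assumption for illustration (RandomAffine is assumed
# to be configured without scaling, since scaling changes distances).
DISTANCE_MAP_SAFE = {"RandomFlip", "RandomAffine"}

INCOMPATIBLE_REASONS = {
    "RandomElasticDeformation": "non-rigid warps do not preserve distances to the boundary",
    "RandomNoise": "intensity noise would corrupt the distance values",
}


class IncompatibleTransformError(ValueError):
    """Raised when a transform should not be applied to a distance map."""


def check_distance_map_transforms(transform_names):
    """Return the names as a list if all are safe, otherwise raise with a reason."""
    for name in transform_names:
        if name in DISTANCE_MAP_SAFE:
            continue
        reason = INCOMPATIBLE_REASONS.get(name, "not known to preserve distances")
        raise IncompatibleTransformError(
            f"{name} cannot be applied to a distance map: {reason}"
        )
    return list(transform_names)
```

Calling `check_distance_map_transforms(["RandomFlip", "RandomElasticDeformation"])` would then fail early with a message naming the offending transform, rather than silently producing a distorted distance map.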
Hoel
from boundary-loss.
Hi, and thanks for the interest in our work.
I'm trying to give you a partial reply now, but I might come back in a few days once I've had more time to think about it.
I didn't have to perform data augmentation for that work, though the options are there for offline data augmentation: https://github.com/LIVIAETS/boundary-loss/blob/master/preprocess/slice_wmh.py#L133
I.e., it will create N augmented copies for each image. (Though now that I think of it, I should probably invert the two loops now that I pre-compute a 3D distance map.)
In proper 3D, if you subpatch the volume, then yeah, you cannot even recompute a correct distance map based solely on the patch, so pre-computing is the only way.
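To illustrate why a patch alone is not enough, here is a small sketch using `scipy.ndimage.distance_transform_edt` on a toy 2D array. The layout and the `distance_to_foreground` helper are made up for the example, and it uses an unsigned distance for brevity. An object that lies just outside the patch still determines the true distance for pixels inside it.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

# Toy 2D "volume" with two foreground objects (layout chosen for illustration).
mask = np.zeros((8, 20), dtype=np.uint8)
mask[2:5, 2:5] = 1     # object A, inside the patch
mask[2:5, 12:15] = 1   # object B, outside the patch


def distance_to_foreground(binary_mask):
    # distance_transform_edt measures the distance to the nearest zero,
    # so pass the background mask to get each pixel's distance to the foreground.
    return distance_transform_edt(binary_mask == 0)


patch = (slice(None), slice(0, 10))  # left half only; object B is cut off

precomputed = distance_to_foreground(mask)[patch]  # crop the full-volume map
recomputed = distance_to_foreground(mask[patch])   # recompute from the patch alone

# Pixel (3, 9) is 3 pixels away from object B but 5 pixels away from object A.
print(precomputed[3, 9], recomputed[3, 9])
```

The two maps disagree wherever the nearest boundary lies outside the patch, which is the case pre-computing on the whole volume avoids.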
Online data augmentation (what you are referring to) is indeed much more tricky. The timing of your question is interesting, as a few days back I was experimenting with it, noting that augmenting a ground-truth that is already one-hot encoded will break the simplex for a few pixels -- even a simple rotation would break it.
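One plausible illustration of this effect, as a sketch with `scipy.ndimage.rotate` on a toy two-class ground truth (the shapes, angle and padding mode are arbitrary choices, not taken from the repository). Interpolating each class channel produces fractional values along class borders, and pixels rotated in from outside the field of view are zero for every class, so their class scores no longer sum to one.

```python
import numpy as np
from scipy.ndimage import rotate

# Two-class toy ground truth: a square of foreground on a background.
labels = np.zeros((32, 32), dtype=np.int64)
labels[8:24, 8:24] = 1
one_hot = np.stack([labels == 0, labels == 1]).astype(np.float32)  # (K, H, W)

# Rotate every class channel by 30 degrees with linear interpolation,
# filling pixels that come from outside the image with zeros.
rotated = rotate(one_hot, angle=30, axes=(1, 2), reshape=False,
                 order=1, mode="constant", cval=0.0)

class_sum = rotated.sum(axis=0)
off_simplex = np.abs(class_sum - 1) > 1e-6
fractional = (rotated > 1e-6) & (rotated < 1 - 1e-6)

print("pixels whose classes no longer sum to 1:", int(off_simplex.sum()))
print("entries that are neither 0 nor 1:", int(fractional.sum()))
```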
Now, a distance map might be a bit more resilient to that, depending on what interpolation you are using (I would say BILINEAR in place of NEAREST, to handle elastic deformations) and what actual transforms you are performing. The way I see it (in my mind; this is theoretical), the following transforms would be possible:
- horizontal/vertical flip
- rotation
- translation
- affine (in the sense of rotation + translation)
- and possibly a few others
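For the transforms that map the pixel grid exactly onto itself (flips and 90-degree rotations), this can be checked numerically: transforming a pre-computed distance map should give the same result as recomputing the map from the transformed mask. The sketch below assumes a simple signed distance (positive outside the object, negative inside); `signed_distance` is a hypothetical helper written for this example.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt


def signed_distance(mask):
    """Signed distance sketch: positive outside the object, negative inside."""
    mask = mask.astype(bool)
    outside = distance_transform_edt(~mask)  # background pixels -> nearest foreground
    inside = distance_transform_edt(mask)    # foreground pixels -> nearest background
    return outside * ~mask - (inside - 1) * mask


mask = np.zeros((16, 16), dtype=bool)
mask[3:9, 2:7] = True  # off-centre rectangle, so flips actually move it

dist = signed_distance(mask)

transforms = {
    "horizontal flip": lambda a: np.flip(a, axis=1),
    "vertical flip": lambda a: np.flip(a, axis=0),
    "90 degree rotation": lambda a: np.rot90(a),
}

# For each transform, compare "transform the map" with "recompute from the mask".
results = {
    name: bool(np.allclose(t(dist), signed_distance(t(mask))))
    for name, t in transforms.items()
}
print(results)
```

Arbitrary rotations and translations need interpolation and padding, so for those the agreement would only be approximate and would depend on how the borders are handled.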
As such, I would say that you can apply the same transforms, with different interpolation:
- image: BILINEAR
- class-labels: NEAREST
- one-hot labels: don't
- distance map: BILINEAR
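The list above can be sketched with `scipy.ndimage.rotate`, where `order=0` is nearest-neighbour and `order=1` is linear interpolation. The `ORDER` table and `rotate_sample` helper are hypothetical names for this example; the toy data, the angle and the `mode="nearest"` padding are arbitrary choices, and the distance map is unsigned for brevity.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt, rotate

# Interpolation order per target, following the list above
# (0 = nearest neighbour, 1 = linear). One-hot labels are deliberately absent.
ORDER = {"image": 1, "labels": 0, "dist_map": 1}


def rotate_sample(sample, angle):
    """Apply one rotation to every entry, switching only the interpolation."""
    return {
        key: rotate(array, angle, reshape=False, order=ORDER[key], mode="nearest")
        for key, array in sample.items()
    }


labels = np.zeros((64, 64), dtype=np.int64)
labels[20:44, 16:40] = 1
rng = np.random.default_rng(0)
sample = {
    "image": labels + 0.1 * rng.standard_normal(labels.shape),
    "labels": labels,
    "dist_map": distance_transform_edt(labels == 0),  # unsigned, for brevity
}

augmented = rotate_sample(sample, angle=20)

# Nearest-neighbour interpolation keeps the labels categorical.
print("label values:", np.unique(augmented["labels"]))

# Compare the interpolated map with one recomputed from the rotated labels.
recomputed = distance_transform_edt(augmented["labels"] == 0)
print("max abs difference:", np.abs(augmented["dist_map"] - recomputed).max())
```

The printed difference gives a rough idea of how far the interpolated map drifts from a freshly computed one; the largest errors would be expected near the image borders, where padding replaces real data.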
Feel free to post code snippets of the data augmentation that you are using; it might help to test its feasibility.
Thanks for your reply Hoel! I didn't realize that you had the option of computing data augmentation offline as well as caching the distance map in individual slices. This was very useful to see. You are correct that, for the purposes of my work, I'm performing the segmentation in proper 3D, where the input is a volume subpatch. Overlapping volume subpatches are created, the segmentation is performed on each subpatch, and the results are assembled at the end to provide a volumetric 3D segmentation. Because of the additional degree of freedom in the z direction, there would need to be many more offline copies of each volume for augmentation than there would be slices, so I haven't explored that avenue and will opt for online augmentation.
The transforms you've listed also make perfect sense. In fact, I may have to resort to using only these if I want to incorporate the boundary loss into what I'm doing. For perspective, however, I am using torchio as a framework to provide batches of volume subpatches in the training loop, which I can feed directly into the model to update its weights.
Here is a function I wrote that returns the transforms I apply for training and validation:
import torchio as tio
from typing import Tuple


def get_transforms() -> Tuple[tio.Compose, tio.Compose]:
    """
    Returns transform chains for training and validation/testing datasets.

    Returns:
        A tuple of torchio.Compose objects for the transform chains.
    """
    training_transform = tio.Compose([
        tio.ToCanonical(),
        tio.Resample(1),
        tio.RandomMotion(p=0.2),
        tio.RandomBiasField(p=0.3),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.RandomNoise(p=0.5),
        tio.RandomFlip(),
        # Apply exactly one of the two spatial transforms, chosen by weight
        tio.OneOf({
            tio.RandomAffine(): 0.8,
            tio.RandomElasticDeformation(): 0.2,
        }),
    ])
    validation_transform = tio.Compose([
        tio.ToCanonical(),
        tio.Resample(1),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ])
    return training_transform, validation_transform
In order, ToCanonical reorients the volume to RAS+ orientation (axes pointing towards Right, Anterior and Superior), and Resample resamples the volume so that each voxel is 1 mm isotropic (1 mm^3). RandomMotion and RandomBiasField randomly simulate, respectively, the motion artifacts that would be seen in an MRI if the subject moved during the scan, and intensity variations due to field inhomogeneity. We finally z-normalize, add random noise, randomly flip along one of the axes, then apply either a random affine transform or a random elastic deformation. The validation transforms are only a subset: we put the volume in canonical form, resample to 1 mm isotropic and z-normalize.
Afterwards, you can build a list of torchio.Subject objects and put it into a torchio.SubjectsDataset. Here is the training set, for example, given lists of paths for the volumes and their corresponding labels:
# Build a list of subjects - one for each volume
subjects_train = [
    tio.Subject(
        volume=tio.ScalarImage(image_path),
        labels=tio.LabelMap(label_path),
    )
    for (image_path, label_path) in zip(image_paths_train, label_paths_train)
]
# Build dataset given this list
dataset_train = tio.SubjectsDataset(
    subjects_train, transform=training_transform)
Finally, you can provide this to a torchio.Queue, which you can wrap in a torch.utils.data.DataLoader:
import torch

patch_size = 256, 256, 16
sampler = tio.data.UniformSampler(patch_size)
patches_training_set = tio.Queue(
    subjects_dataset=dataset_train,
    max_length=300,
    samples_per_volume=20,
    sampler=sampler,
    num_workers=8,
    shuffle_subjects=True,
    shuffle_patches=True,
)
training_loader_patches = torch.utils.data.DataLoader(
    patches_training_set, batch_size=16)
The queue provides volume subpatches efficiently: CPU workers fill it up while you dequeue patches for the model update. What is nice is that the labels use a torchio.LabelMap, so any intensity transformations are ignored, as these are treated as categorical labels. I can therefore see using this to wrap the signed distance maps, so that we only apply transformations that would also be applicable to segmentation masks. Judging from your discussion, this seems to be the course of action to take.
With this in place, you can start training your models. The pipeline will already provide augmented volume subpatches: the tio.SubjectsDataset first applies the transform to the whole volume, and the correct subpatches are then sampled from this augmented volume when you iterate over the data loader:
# Define model and loss criterion here
model = ...
criterion = ...
# Define optimizer
optimizer = torch.optim.Adam(model.parameters())
for batch_idx, batch in enumerate(training_loader_patches):
    # move to GPU
    inputs = batch['volume'][tio.DATA].cuda()
    targets = batch['labels'][tio.DATA].cuda()
    # find the loss and update the model parameters accordingly
    optimizer.zero_grad()
    preds = model(inputs)
    loss = criterion(preds, targets)
    loss.backward()
    optimizer.step()
I can see how using at most affine transformations on the distance map would properly reflect the actual distances to the ground truth surface with the exception of the scale. Given the above, I'd like to know what your thoughts are when you come back to this, especially since I'm using a framework that directly provides volume subpatches in an easy manner.
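The scale caveat can be made concrete with a small sketch. Here the `sampling` argument of `scipy.ndimage.distance_transform_edt` is used as a stand-in for an isotropic zoom (an assumption made for this illustration; an actual resampling would also involve interpolation): stretching the grid by a factor multiplies every distance by that factor.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

mask = np.zeros((32, 32), dtype=bool)
mask[10:20, 12:22] = True

# Distances measured with unit spacing versus a spacing of 1.5 along each axis.
scale = 1.5
unit = distance_transform_edt(~mask)
scaled = distance_transform_edt(~mask, sampling=(scale, scale))

# Under isotropic scaling the nearest boundary pixel is unchanged,
# so every distance is simply multiplied by the scale factor.
matches = bool(np.allclose(scaled, scale * unit))
print(matches)
```

Under that assumption, a distance map that has gone through an isotropic scaling by `s` would need its values multiplied by `s` to stay consistent with the scaled mask; anisotropic scaling has no such single correction factor.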
Thanks!
Thanks for your comment Hoel! Yes, torchio is relatively new. I always had to create the data loader myself to extract subvolumes for training, but this makes it a lot easier, so I started using it just recently. I've got some deadlines myself, but once they're finished I wouldn't mind looking at this and seeing if I could make a PR to add this functionality to the current framework. Until next time!