Comments (2)
A few comments on this.
- Your initial proposal seems to be limited to binary classes and vector datasets. However, class imbalance is a problem faced by all 2+ class raster and vector datasets. It would be ideal to have an implementation that works for all of these.
- We should draw inspiration from (or at least point out) PyTorch's WeightedRandomSampler, which is designed to handle these kinds of cases. Of course, our problem is a bit more complicated than this.
- Your implementation isn't clear. Maybe you could write it in pseudo-code? I can't tell if there are two for-loops (steps 1 and 5) or just one (step 1).
- Your implementation is extremely inefficient. Instead of loading a single patch from disk, you have to load a much bigger patch from disk and choose a smaller patch within that. For only a 2x2 grid cell, this means 4x as much I/O. For VectorDatasets, you'll also need to rasterize 4x as much area. Not saying I have better ideas, just pointing out that this may be prohibitively slow.
- Another option to deal with class imbalance is to weight the accuracy of each class in the loss function. You may find that this is significantly easier to implement (it should already be supported) and computationally efficient. However, I don't know whether dealing with class imbalance on the sampling side or on the loss side is better.
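Both alternatives raised above (weighted sampling and a class-weighted loss) can be sketched in plain PyTorch. This is a hypothetical illustration with made-up class counts, not torchgeo's `GeoSampler` API:

```python
import torch
from torch import nn
from torch.utils.data import WeightedRandomSampler

# Imbalanced toy labels: 90 negatives, 10 positives
labels = torch.tensor([0] * 90 + [1] * 10)
counts = torch.bincount(labels).float()

# Option 1: oversample the rare class with WeightedRandomSampler.
# Each sample's weight is the inverse frequency of its class.
sample_weights = 1.0 / counts[labels]
sampler = WeightedRandomSampler(sample_weights, num_samples=100, replacement=True)

# Option 2: keep uniform sampling, but weight the loss per class
class_weights = counts.sum() / (len(counts) * counts)
criterion = nn.CrossEntropyLoss(weight=class_weights)
```

With `replacement=True` the sampler draws positives and negatives in roughly equal proportion despite the 9:1 imbalance, while the weighted loss leaves the data untouched and instead scales each class's gradient contribution.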
Thanks for your feedback.
Indeed, my initial thoughts have a narrower scope to keep the complexity down while figuring this out. I do think a general approach for 2+ classes, and even raster & raster `IntersectionDataset`s, can be extrapolated.
Also, as I currently have a use case for a raster & vector `IntersectionDataset` with binary classes, it is easier for me to get to a proof of concept.
I agree that we should draw inspiration from their API and get to a similar usage.
There is one loop, as before; my initial points 1 and 2 would be inside the loop for `BalancedRandomSampler`, and outside of the loop for `BalancedRandomBatchSampler`.
Here we have a misunderstanding. No more pixels (raster data) will be loaded than today. My suggested implementation is based on spatial operations on vector data (not rasterised masks), so the rasters will be loaded the same way as today, and only after the sample has been chosen. I will try to illustrate this in the example below.
It is possible. But for sparse positives, e.g. when the label vector file represents a diagonal river across the image, the `IntersectionDataset` hit (rectified bounds of the intersection) would contain mainly negative pixels, meaning a low likelihood of randomly selecting a positive. Randomly selecting samples from it would yield e.g. 99% negative and 1% positive samples. Dealing with the imbalance in the loss function alone would then require training for a very long time before the model has seen "enough" positives, meaning many more epochs.
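The sparse-positive argument can be made concrete with a small shapely sketch. The numbers here are hypothetical: a 100x100 tile and a river modelled as a 4-unit-wide diagonal strip:

```python
import shapely
from shapely.geometry import LineString, box

# Hypothetical 100 x 100 tile with a diagonal river 4 units wide
tile = box(0, 0, 100, 100)
river = LineString([(0, 0), (100, 100)]).buffer(2.0)
positive = river.intersection(tile)

# Fraction of the tile a uniformly random sample centre lands on the river
pos_frac = positive.area / tile.area  # roughly 0.05-0.06

# A uniform sampler thus needs about 1 / pos_frac draws per positive sample
draws_per_positive = 1.0 / pos_frac
```

So even in this mild example a uniform sampler wastes most of its draws on negatives; for thinner features or larger tiles the ratio gets much worse.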
Here are the main parts of our implementation. It is tested and works, but I did not include the changes to the `__init__` methods of the datasets.
This version differs from my initial description in one way: points 2 and 4 are not implemented. Instead of creating a grid and assigning positive and negative contents to its cells, it randomly generates bounds for the new sample, checks whether they contain labels (a vector operation), and tries again if they do not (as in #1881). Whether to yield a positive or a negative sample is determined by keeping track of how many of each have been yielded so far, and comparing that with the user-specified `pos_neg_frac`.
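The bookkeeping that decides whether the next sample should be positive or negative can be sketched in isolation. This is a hypothetical stand-alone helper mirroring the counter logic described above, not the actual sampler code:

```python
def next_sample_kind(pos_sampled: int, neg_sampled: int,
                     pos_neg_frac: float = 0.5) -> str:
    """Return 'positive' or 'negative' based on the running fraction."""
    # The short-circuit avoids division by zero on the very first draw
    if pos_sampled == 0 or pos_sampled / (pos_sampled + neg_sampled) < pos_neg_frac:
        return "positive"
    return "negative"

pos, neg = 0, 0
for _ in range(10):
    if next_sample_kind(pos, neg) == "positive":
        pos += 1
    else:
        neg += 1
# With pos_neg_frac=0.5 the counters stay balanced: pos == neg == 5
```

The rule is greedy: whenever the positive fraction so far drops below `pos_neg_frac`, the next sample is forced positive, so the yielded stream converges to the requested ratio regardless of the underlying class imbalance.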
```python
class IntersectionDataset(GeoDataset):
    def _merge_dataset_indices(self) -> None:
        """Create a new R-tree out of the individual indices from two datasets."""
        i = 0
        ds1, ds2 = self.datasets
        # Assuming ds1 is the RasterDataset
        for hit1 in ds1.index.intersection(ds1.index.bounds, objects=True):
            # Assuming ds2 is the (positive) label dataset (VectorDataset or RasterDataset)
            for hit2 in ds2.index.intersection(hit1.bounds, objects=True):
                box1 = BoundingBox(*hit1.bounds)
                box2 = BoundingBox(*hit2.bounds)
                box_intersection = box1 & box2
                # Create a vector geometry of these intersection bounds
                bounds = shapely.geometry.box(
                    minx=box_intersection.minx,
                    miny=box_intersection.miny,
                    maxx=box_intersection.maxx,
                    maxy=box_intersection.maxy,
                )
                # Raster files may contain nodata. `valid_footprint` has been calculated
                # for each file and added to the R-tree index of the RasterDataset.
                # See PR #1881 for discarding nodata samples.
                # __getitem__ already reads across all files overlapping the chosen
                # sample, so we need to collect valid footprints for all of them.
                # For these bounds, gather the valid_footprints of all (raster) files,
                # as the current hit's bounds may contain coverage from other files.
                ds1_footprints = [
                    other.object["valid_footprint"]
                    for other in ds1.index.intersection(
                        tuple(box_intersection), objects=True
                    )
                ]
                # Merge these footprints into one vector feature (MultiPolygon)
                ds1_footprint = shapely.unary_union(ds1_footprints)
                # Crop it to the bounds of the R-tree index hit
                ds1_footprint_in_bounds = shapely.intersection(
                    ds1_footprint, bounds
                )
                # Do the same for the other dataset.
                # For a VectorDataset these are the actual vector features (not
                # rasterized) representing the class coverage. Let's call this
                # valid_footprint too; it has likewise been added to the R-tree
                # index in the VectorDataset init. As above, multiple files may be
                # covered by these bounds, so we need to merge their footprints.
                ds2_footprints = [
                    other.object["valid_footprint"]
                    for other in ds2.index.intersection(
                        tuple(box_intersection), objects=True
                    )
                ]
                ds2_footprint = shapely.unary_union(ds2_footprints)
                ds2_footprint_in_bounds = shapely.intersection(
                    ds2_footprint, bounds
                )
                # Here we made the choice to only include bounds where both datasets
                # have coverage, effectively disregarding negative-only R-tree entries.
                if ds1_footprint_in_bounds.is_empty or ds2_footprint_in_bounds.is_empty:
                    continue
                # The intersection of the datasets' footprints gives us the polygon
                # of the positive regions
                positive_polygon = shapely.intersection(
                    ds1_footprint_in_bounds, ds2_footprint_in_bounds
                )
                # The difference gives us the polygon of the negative regions,
                # assuming ds1 is the raster and ds2 is the labels
                negative_polygon = ds1_footprint_in_bounds - positive_polygon
                # Add these to the new index
                self.index.insert(
                    i,
                    tuple(box_intersection),
                    {
                        "positive_polygon": positive_polygon,
                        "negative_polygon": negative_polygon,
                    },
                )
                i += 1
        if i == 0:
            raise RuntimeError("Datasets have no spatiotemporal intersection")
```
```python
class RandomGeoSampler(GeoSampler):
    def __init__(
        self,
        dataset: GeoDataset,
        size: Union[tuple[float, float], float],
        length: Optional[int],
        roi: Optional[BoundingBox] = None,
        units: Units = Units.PIXELS,
        exclude_nodata_samples: bool = False,
        max_retries: int = 50_000,
        pos_neg_frac: float = 0.5,
    ) -> None:
        """
        Args:
            exclude_nodata_samples: ensures that samples do not fall outside the
                footprint of the valid pixels. Nodata regions may occur due to
                re-projection or inherent nodata regions in rasters.
            max_retries: used when exclude_nodata_samples is True. A safeguard
                against infinite loops in case the nodata mask of the raster is wrong.
            pos_neg_frac: fraction of positive samples
        """
        super().__init__(dataset, roi)
        self.size = _to_tuple(size)
        self.exclude_nodata_samples = exclude_nodata_samples
        self.max_retries = max_retries
        self.pos_neg_frac = pos_neg_frac
        self.pos_sampled = 0
        self.neg_sampled = 0
        # Rest of __init__ as is

    def __iter__(self) -> Iterator[BoundingBox]:
        """Return the index of a dataset.

        Returns:
            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
        """
        for _ in range(len(self)):
            # Choose a random tile, weighted by area
            idx = torch.multinomial(self.areas, 1)
            hit = self.hits[idx]
            bounds = BoundingBox(*hit.bounds)
            # New code balancing positive and negative samples
            # by checking that they overlap with the corresponding polygons.
            # Counts how many of each have been sampled already,
            # and decides what should be sampled next.
            if (
                self.pos_sampled == 0
                or (self.pos_sampled / (self.pos_sampled + self.neg_sampled))
                < self.pos_neg_frac
            ):
                footprint = hit.object["positive_polygon"]
                spatial_operator = shapely.overlaps  # ensures some positive
                self.pos_sampled += 1
            else:
                footprint = hit.object["negative_polygon"]
                spatial_operator = shapely.within  # ensures no positive
                self.neg_sampled += 1
            # Method implemented in #1881
            # https://github.com/microsoft/torchgeo/blob/bbb21d5adb91c5e9fabf1efc83507515532f94b9/torchgeo/samplers/utils.py#L81
            yield get_random_bounding_box_check_valid_overlap(
                bounds=bounds,
                size=self.size,
                res=self.res,
                valid_footprint=footprint,
                max_retries=self.max_retries,
                spatial_operator=spatial_operator,
            )
```
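For reference, the two spatial predicates used for the positive and negative branches behave as follows. This is a minimal shapely sketch with made-up boxes:

```python
import shapely
from shapely.geometry import box

positive_polygon = box(0, 0, 10, 10)

partly_inside = box(8, 8, 12, 12)    # straddles the positive region
fully_outside = box(20, 20, 24, 24)  # contains no positive pixels

# shapely.overlaps: the geometries share some, but not all, interior points,
# so a sample straddling the positive polygon contains some positives
assert shapely.overlaps(partly_inside, positive_polygon)
assert not shapely.overlaps(fully_outside, positive_polygon)

# shapely.within: the sample lies entirely inside the (negative) polygon
negative_polygon = box(15, 15, 30, 30)
assert shapely.within(fully_outside, negative_polygon)
assert not shapely.within(partly_inside, negative_polygon)
```

One caveat worth checking: `shapely.overlaps` is False when one geometry fully contains the other, so a small sample lying entirely inside the positive polygon would be rejected by the positive branch; `shapely.intersects` may be worth considering there instead.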