Comments (2)

adamjstewart avatar adamjstewart commented on June 5, 2024

A few comments on this.

  1. Your initial proposal seems to be limited to binary classes and vector datasets. However, class imbalance is a problem faced by all 2+ class raster and vector datasets. It would be ideal to have an implementation that works for all of these.
  2. We should draw inspiration from (or at least point out) PyTorch's WeightedRandomSampler, which is designed to handle these kinds of cases. Of course, our problem is a bit more complicated than this.
  3. Your implementation isn't clear. Maybe you could write it in pseudo-code? I can't tell if there are two for-loops (steps 1 and 5) or just one (step 1).
  4. Your implementation is extremely inefficient. Instead of loading a single patch from disk, you have to load a much bigger patch from disk and choose a smaller patch within that. For only a 2x2 grid cell, this means 4x as much I/O. For VectorDatasets, you'll also need to rasterize 4x as much area. Not saying I have better ideas, just pointing out that this may be prohibitively slow.
  5. Another option to deal with class imbalance is to weight the accuracy of each class in the loss function. You may find that this is significantly easier to implement (it should already be supported) and computationally efficient. However, I don't know whether dealing with class imbalance on the sampling side or on the loss side is better.
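As a rough illustration of option 5 (this is a generic sketch, not torchgeo's API): inverse-frequency class weights of the kind typically passed to PyTorch's `CrossEntropyLoss(weight=...)` can be computed from per-class pixel counts alone. The counts below are made-up example numbers.

```python
# Sketch: inverse-frequency class weights for a weighted loss.
# `pixel_counts` is a hypothetical per-class pixel tally; the formula
# total / (n_classes * count) is one common convention, not a torchgeo API.

def inverse_frequency_weights(pixel_counts: list[int]) -> list[float]:
    """Return one weight per class, larger for rarer classes."""
    total = sum(pixel_counts)
    n_classes = len(pixel_counts)
    return [total / (n_classes * count) for count in pixel_counts]

# 99% negative pixels, 1% positive pixels
weights = inverse_frequency_weights([9900, 100])
# The rare class gets ~99x the weight of the common class:
# weights ≈ [0.505, 50.0]
```

The resulting list can be wrapped in a tensor and handed to the loss function, so no change to the sampling pipeline is needed.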

adriantre avatar adriantre commented on June 5, 2024

Thanks for your feedback.

Indeed, my initial thoughts have a reduced scope to reduce complexity while figuring it out. I do think a general approach for 2+ classes and even raster & raster IntersectionDatasets can be extrapolated.

Also, as I currently have a use case for raster & vector with binary classes, it is easier for me to get to a proof-of-concept.

I agree that we should draw inspiration from their API and get to a similar usage.

There is one loop, as before. My initial points 1 and 2 would be inside the loop for BalancedRandomSampler, and outside the loop for BalancedRandomBatchSampler.

Here we have a misunderstanding: no more pixels (raster data) will be loaded than today. My suggested implementation is based on spatial operations on vector data (not rasterised masks), so the rasters are loaded the same way as today, after the sample has been chosen. I will try to illustrate this in the example below.

It is possible. But for sparse positives, e.g. when the label vector file represents a diagonal river across the image, the IntersectionDataset hit (rectified bounds of the intersection) would contain mainly negative pixels, meaning a low likelihood of randomly selecting a positive. Randomly selecting samples from it would yield e.g. 99% negative and 1% positive samples. Dealing with imbalance in the loss function only would then require training for a very long time before the model has seen "enough" positives, i.e. many more epochs.
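That cost can be made concrete: with purely random sampling, the expected number of patches drawn per positive patch is 1/p (the mean of a geometric distribution with success probability p). A tiny sketch, where the 1% figure is just the example above, not a measurement:

```python
def expected_draws_per_positive(pos_frac: float) -> float:
    """Mean number of random samples needed to see one positive patch
    (geometric distribution with success probability pos_frac)."""
    return 1.0 / pos_frac

# With 1% positives, the model sees one positive patch per ~100 draws,
# so seeing 1000 positives takes ~100,000 draws.
print(expected_draws_per_positive(0.01))  # 100.0
```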


Here are the main parts of our implementation. It is tested and works, but I did not include the changes to the `__init__` methods of the Datasets.

This version has one difference from my initial description: points 2 and 4 are not implemented. Instead of creating a grid and assigning positive and negative contents to its cells, it randomly generates bounds for the new sample, checks whether they contain labels (a vector operation), and tries again if they do not (as in #1881). Whether to yield a positive or negative sample is determined by keeping track of how many of each have been yielded so far and comparing that with the user-specified pos_neg_frac.

class IntersectionDataset(GeoDataset):
    def _merge_dataset_indices(self) -> None:
        """Create a new R-tree out of the individual indices from two datasets."""
        i = 0
        ds1, ds2 = self.datasets
        # Assuming ds1 is the RasterDataset
        for hit1 in ds1.index.intersection(ds1.index.bounds, objects=True):
            # Assuming ds2 is the (positive) label dataset (VectorDataset or RasterDataset)
            for hit2 in ds2.index.intersection(hit1.bounds, objects=True):
                box1 = BoundingBox(*hit1.bounds)
                box2 = BoundingBox(*hit2.bounds)
                box_intersection = box1 & box2

                # Create vector-geometry of these intersection bounds
                bounds = shapely.geometry.box(
                    minx=box_intersection.minx,
                    miny=box_intersection.miny,
                    maxx=box_intersection.maxx,
                    maxy=box_intersection.maxy,
                )
                
                # Raster files may contain nodata. `valid_footprint` has been
                # calculated for each file and added to the rtree index of the
                # RasterDataset. See PR #1881 for discarding nodata samples.

                # __getitem__ already reads across all files overlapping the
                # chosen sample, so we need to collect the valid footprints of
                # all of them. For these bounds, gather the valid_footprints of
                # all (raster) files, as the current hit's bounds may contain
                # coverage from other files.
                ds1_footprints = [
                    other.object["valid_footprint"]
                    for other in ds1.index.intersection(
                        tuple(box_intersection), objects=True
                    )
                ]

                # Merge these footprints to one vector feature (MultiPolygon)
                ds1_footprint = shapely.unary_union(ds1_footprints)

                # Crop it to the bounds of the rtree-index hit
                ds1_footprint_in_bounds = shapely.intersection(
                    ds1_footprint, bounds
                )

                # Do the same thing for the other dataset.
                # For a VectorDataset this is the actual vector features (not
                # rasterized) representing the class coverage. Let's call this
                # valid_footprint too. It has likewise been added to the rtree
                # index in the VectorDataset __init__. As above, multiple files
                # may be covered by these bounds, so we need to merge their
                # footprints.
                ds2_footprints = [
                    other.object["valid_footprint"]
                    for other in ds2.index.intersection(
                        tuple(box_intersection), objects=True
                    )
                ]
                ds2_footprint = shapely.unary_union(ds2_footprints)
                ds2_footprint_in_bounds = shapely.intersection(
                    ds2_footprint, bounds
                )

                # Here we made the choice to only include bounds where both
                # datasets have coverage, effectively disregarding
                # negative-only rtree entries.
                if ds1_footprint_in_bounds.is_empty or ds2_footprint_in_bounds.is_empty:
                    continue
                
                # The intersection of the datasets' footprints gives us the polygon of positive regions
                positive_polygon = shapely.intersection(ds1_footprint_in_bounds, ds2_footprint_in_bounds)

                # The difference gives us the polygon of the negative regions,
                # assuming ds1 is the raster and ds2 is the labels.
                # Add both to the new index.
                negative_polygon = ds1_footprint_in_bounds - positive_polygon
                self.index.insert(
                    i,
                    tuple(box_intersection),
                    {
                        "positive_polygon": positive_polygon,
                        "negative_polygon": negative_polygon,
                    },
                )
                i += 1

        if i == 0:
            raise RuntimeError("Datasets have no spatiotemporal intersection")
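The positive/negative split computed in `_merge_dataset_indices` can be checked in isolation with shapely; the box coordinates below are made up purely for illustration.

```python
import shapely
from shapely.geometry import box

# Hypothetical footprints: a raster tile and a label layer that
# only partially overlap it.
raster_footprint = box(0, 0, 10, 10)  # valid pixels of the raster
label_footprint = box(5, 5, 15, 15)   # class coverage from the labels
bounds = box(0, 0, 10, 10)            # rtree-hit bounds, here equal to the raster

# Same operations as in _merge_dataset_indices above: crop both
# footprints to the bounds, then intersect/subtract them.
raster_in_bounds = shapely.intersection(raster_footprint, bounds)
label_in_bounds = shapely.intersection(label_footprint, bounds)

positive_polygon = shapely.intersection(raster_in_bounds, label_in_bounds)
negative_polygon = raster_in_bounds - positive_polygon

print(positive_polygon.area)  # 25.0: the 5x5 overlap
print(negative_polygon.area)  # 75.0: the rest of the raster tile
```

The two areas partition the raster footprint, which is exactly the invariant the sampler below relies on.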


class RandomGeoSampler(GeoSampler):
    def __init__(
        self,
        dataset: GeoDataset,
        size: Union[tuple[float, float], float],
        length: Optional[int],
        roi: Optional[BoundingBox] = None,
        units: Units = Units.PIXELS,
        exclude_nodata_samples: bool = False,
        max_retries: int = 50_000,
        pos_neg_frac: float = 0.5,
    ) -> None:
        """
        Args:
            exclude_nodata_samples: ensures that samples do not fall outside
                the footprint of valid pixels. Nodata regions may occur due to
                re-projection or inherent nodata regions in the rasters.
            max_retries: used when exclude_nodata_samples is True; a safeguard
                against infinite loops in case the nodata mask of the raster is wrong.
            pos_neg_frac: fraction of positive samples
        """
        super().__init__(dataset, roi)
        self.size = _to_tuple(size)
        self.exclude_nodata_samples = exclude_nodata_samples
        self.max_retries = max_retries
        self.pos_neg_frac = pos_neg_frac
        self.pos_sampled = 0
        self.neg_sampled = 0

        # Rest of __init__ as is

    def __iter__(self) -> Iterator[BoundingBox]:
        """Return the index of a dataset.

        Returns:
            (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
        """
        for _ in range(len(self)):
            # Choose a random tile, weighted by area
            idx = torch.multinomial(self.areas, 1)
            hit = self.hits[idx]
            bounds = BoundingBox(*hit.bounds)

            # New code, balancing positive and negative samples
            # by checking that they overlap with the corresponding polygons
            # Counts how many of each has been sampled already, 
            # and decides what should be sampled next.
            if (
                self.pos_sampled == 0
                or (self.pos_sampled / (self.pos_sampled + self.neg_sampled))
                < self.pos_neg_frac
            ):
                footprint = hit.object["positive_polygon"]
                spatial_operator = shapely.overlaps  # ensures some positive
                self.pos_sampled += 1
            else:
                footprint = hit.object["negative_polygon"]
                spatial_operator = shapely.within  # ensures no positive
                self.neg_sampled += 1

            # Method implemented in #1881
            # https://github.com/microsoft/torchgeo/blob/bbb21d5adb91c5e9fabf1efc83507515532f94b9/torchgeo/samplers/utils.py#L81
            yield get_random_bounding_box_check_valid_overlap(
                bounds=bounds,
                size=self.size,
                res=self.res,
                valid_footprint=footprint,
                max_retries=self.max_retries,
                spatial_operator=spatial_operator,
            )
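The ratio-tracking decision in `__iter__` can be isolated and sanity-checked on its own; this is a pure-Python sketch independent of torchgeo, with the helper name invented for the example.

```python
def choose_positive(pos_sampled: int, neg_sampled: int, pos_neg_frac: float) -> bool:
    """Same decision rule as in __iter__ above: sample a positive patch
    whenever the running positive fraction is below the target."""
    if pos_sampled == 0:
        return True
    return pos_sampled / (pos_sampled + neg_sampled) < pos_neg_frac

# Simulate 100 draws with a 50/50 target
pos, neg = 0, 0
for _ in range(100):
    if choose_positive(pos, neg, pos_neg_frac=0.5):
        pos += 1
    else:
        neg += 1
print(pos, neg)  # 50 50
```

The running fraction converges to pos_neg_frac regardless of the underlying class distribution, which is the point of doing the balancing on the sampling side.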
