
Comments (2)

HKervadec commented on August 16, 2024

Hi,

I think the issue is that mask (which is already one-hot encoded) goes through the dist_map_transform dataloader function, whereas you could send it directly to one_hot2dist. For reference, here is what dist_map_transform does:

def dist_map_transform(resolution: Tuple[float, ...], K: int) -> Callable[[D], Tensor]:
        return transforms.Compose([
                gt_transform(resolution, K),
                lambda t: t.cpu().numpy(),
                partial(one_hot2dist, resolution=resolution),
                lambda nd: torch.tensor(nd, dtype=torch.float32)
        ])


def gt_transform(resolution: Tuple[float, ...], K: int) -> Callable[[D], Tensor]:
        return transforms.Compose([
                lambda img: np.array(img)[...],
                lambda nd: torch.tensor(nd, dtype=torch.int64)[None, ...],  # Add one dimension to simulate batch
                partial(class2one_hot, K=K),
                itemgetter(0)  # Then pop the element to go back to img shape
        ])

So, end-to-end it gives:

        transforms.Compose([
                lambda img: np.array(img)[...],
                lambda nd: torch.tensor(nd, dtype=torch.int64)[None, ...],  # Add one dimension to simulate batch
                partial(class2one_hot, K=K),
                itemgetter(0),  # Then pop the element to go back to img shape
                lambda t: t.cpu().numpy(),
                partial(one_hot2dist, resolution=resolution),
                lambda nd: torch.tensor(nd, dtype=torch.float32)
        ])

What this means, I believe, is that you are one-hot encoding your already one-hot-encoded mask before sending it to one_hot2dist. This is what I think causes the extra dimensions and breaks some asserts.
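
To make the shape problem concrete, here is a small sketch, assuming the class2one_hot and one_hot2dist helpers from this repo's utils.py are importable and using toy sizes:

import torch
from utils import class2one_hot, one_hot2dist  # helpers from this repo's utils.py

H, W = 4, 4
labels = torch.randint(0, 3, (1, H, W))        # integer label map with a fake batch axis
mask = class2one_hot(labels, K=3)[0]           # one-hot mask, shape (3, H, W)

# dist_map_transform re-applies class2one_hot to the already one-hot mask:
double = class2one_hot(mask[None, ...].long(), K=3)[0]
print(double.shape)                            # torch.Size([3, 3, 4, 4]): one spurious axis

# sending the one-hot mask straight to one_hot2dist keeps the expected shape:
dist = one_hot2dist(mask.cpu().numpy(), resolution=(1, 1))
print(dist.shape)                              # (3, 4, 4)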

The following would probably work (if not, please let me know how it fails):

def preprocess(self, sample):
        # my code - usual preprocessing

        if self.transform is not None:
            sample = self.transform(sample)
        image, mask = sample

        # loss code: at this point, mask is a tensor of shape (3, H, W) with values 0 or 1;
        # each channel represents one pixel class (r, g, b)
        assert one_hot(mask[None, ...])  # I create an extra axis because one_hot is made to check a whole batch

        # one_hot2dist expects a numpy array and one spacing value per spatial dimension
        dist_map = one_hot2dist(mask.cpu().numpy(), resolution=[1, 1])
        dist_map_tensor = torch.tensor(dist_map, dtype=torch.float32)

        return image, mask, dist_map_tensor
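
For completeness, here is a minimal sketch of how the returned distance map would then be consumed, assuming the dataloader batches the tuples from preprocess and the network output goes through a softmax; the boundary_loss function below is my own illustration of the elementwise-product formulation, not a function from this repo:

import torch

def boundary_loss(probs: torch.Tensor, dist_maps: torch.Tensor) -> torch.Tensor:
    # probs: softmax output of the network, shape (B, K, H, W)
    # dist_maps: batched distance maps from preprocess(), shape (B, K, H, W)
    # the boundary loss is the mean of the elementwise product of the two
    return torch.einsum("bkhw,bkhw->bkhw", probs, dist_maps).mean()

This roughly mirrors what the repo's SurfaceLoss class does, which additionally takes an idc list to select the classes that contribute to the loss.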


gr33n1 commented on August 16, 2024

Thank you

