Comments (3)

PetrochukM commented on July 17, 2024

Why are the target lengths (the third argument in the call below):

loss = criterion(output, target[0], target[1])

used to create a mask on the output?

        mask = self._sequence_mask(target[1]).unsqueeze(2)
        mask_ = mask.expand_as(input)
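A minimal sketch of the shape problem behind this question. It assumes that, when no max_len is passed, the mask length falls back to the longest target (lengths.max()). The helper name sequence_mask and the tensor sizes are hypothetical, chosen only to illustrate the mismatch: the model emits 5 time steps while the longest target has 3.

```python
import torch


def sequence_mask(lengths, max_len):
    # Hypothetical helper: (batch, max_len) comparison, True where t < length,
    # transposed to (max_len, batch) to match time-major tensors.
    steps = torch.arange(max_len, device=lengths.device)
    return (steps.unsqueeze(0) < lengths.unsqueeze(1)).t().float()


# Illustrative shapes: 5 output steps, longest target only 3 frames.
output = torch.randn(5, 2, 4)          # (time, batch, features)
lengths = torch.tensor([3, 2])

# Mask sized from the target lengths: (3, 2, 1) cannot expand to (5, 2, 4).
bad_mask = sequence_mask(lengths, int(lengths.max())).unsqueeze(2)
try:
    bad_mask.expand_as(output)
    mismatch = False
except RuntimeError:
    mismatch = True

# Mask sized from the output: (5, 2, 1) expands cleanly to (5, 2, 4).
good_mask = sequence_mask(lengths, output.size(0)).unsqueeze(2)
expanded = good_mask.expand_as(output)
print(mismatch, tuple(expanded.shape))
```

Sizing the mask from the output's time dimension keeps it broadcast-compatible however many steps the model produces, while the lengths still decide which of those steps count.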

from loop.

PetrochukM commented on July 17, 2024

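Fixed this by taking max_len from the output (input.size(0)) instead of from the target lengths, so the mask always has the output's time dimension: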

import torch
import torch.nn as nn


class MaskedMSE(nn.Module):
    def __init__(self):
        super(MaskedMSE, self).__init__()
        # Sum the squared errors; forward() normalizes by the mask.
        # reduction='sum' replaces the deprecated size_average=False.
        self.criterion = nn.MSELoss(reduction='sum')

    # Taken from
    # https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation
    @staticmethod
    def _sequence_mask(sequence_length, max_len):
        batch_size = sequence_length.size(0)
        # Build the range on the lengths' device (no Variable or .cuda() needed).
        seq_range = torch.arange(0, max_len, device=sequence_length.device).long()
        seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
        seq_length_expand = sequence_length.unsqueeze(1) \
                                           .expand_as(seq_range_expand)
        # 1.0 where t < length; shape (max_len, batch) after the transpose.
        return (seq_range_expand < seq_length_expand).t().float()

    def forward(self, input, target, lengths):
        # input, target: (max_len, batch, features); lengths: (batch,).
        # max_len comes from the output so the mask always matches its shape.
        max_len = input.size(0)
        mask = self._sequence_mask(lengths, max_len).unsqueeze(2)
        mask_ = mask.expand_as(input)
        self.loss = self.criterion(input*mask_, target*mask_)
        # Average over the valid (unpadded) time steps.
        self.loss = self.loss / mask.sum()
        return self.loss
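For a hand-checkable example, here is a compact functional sketch of the same computation as MaskedMSE (the function name masked_mse and the toy tensors are hypothetical). Padded target frames are filled with 99.0 so that any leak through the mask would be obvious.

```python
import torch
import torch.nn.functional as F


def masked_mse(output, target, lengths):
    """Sum of squared errors over unpadded frames, divided by the number
    of unpadded frames (the same normalization as MaskedMSE).

    output, target: (max_len, batch, features); lengths: (batch,)
    """
    max_len = output.size(0)  # size the mask from the output, not the lengths
    steps = torch.arange(max_len, device=lengths.device)
    # (max_len, batch, 1) mask: 1.0 where t < length.
    mask = (steps.unsqueeze(0) < lengths.unsqueeze(1)).t().float().unsqueeze(2)
    mask_ = mask.expand_as(output)
    loss = F.mse_loss(output * mask_, target * mask_, reduction='sum')
    return loss / mask.sum()


# Sequence 0 has 2 valid frames (targets 1, 1); sequence 1 has 1 (target 2).
output = torch.zeros(3, 2, 1)
target = torch.tensor([[[1.0], [2.0]],
                       [[1.0], [99.0]],
                       [[99.0], [99.0]]])
lengths = torch.tensor([2, 1])
loss = masked_mse(output, target, lengths)
print(loss.item())  # 2.0: (1 + 1 + 4) / 3 valid frames
```

Note that mask.sum() counts valid time steps, not valid elements, so with more than one feature per frame the result is the summed error per valid frame rather than a per-element mean.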

adampolyak commented on July 17, 2024

Thanks!
Fixed in master.