
Comments (5)

bentrevett avatar bentrevett commented on May 17, 2024

If by "fix", you mean allow it to be changed, then you can add an argument to the seq2seq class which will take in an integer and set the trg tensor length equal to that specified integer.
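A minimal sketch of that idea, assuming the argument is called `max_len` (the name and helper below are illustrative, not from the repo):

```python
import torch

SOS_IDX = 2  # assumed <sos> vocabulary index, for illustration only

def make_inference_targets(src, max_len=20, sos_idx=SOS_IDX):
    # Build a placeholder trg tensor of the requested length,
    # filled with <sos>; shape = [max_len, batch size].
    return torch.full((max_len, src.shape[1]), sos_idx, dtype=torch.long)

src = torch.zeros(7, 3, dtype=torch.long)  # [src sent len, batch size]
trg = make_inference_targets(src, max_len=50)
```

Passing the length in as an argument lets each call choose its own cap instead of relying on a hard-coded constant.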

from pytorch-seq2seq.

CuriousDeepLearner avatar CuriousDeepLearner commented on May 17, 2024

No, I mean: why did you fix it at 20 in this case? Sorry for the unclear question.


bentrevett avatar bentrevett commented on May 17, 2024

That 20 represents the maximum number of tokens in the inferred translation. If you're expecting your translated phrases to be longer than 20 tokens then this should be increased.


CuriousDeepLearner avatar CuriousDeepLearner commented on May 17, 2024

Yeah, that's what I mean. We can't know before the inference step what the maximum number of tokens in the inferred translation will be. I wonder if we could set it to something like 300 to be safe, but that would also affect inference time. It would be better if we could find another way to do it.


bentrevett avatar bentrevett commented on May 17, 2024

One possible way to do it is to set an arbitrary maximum length, but within the decoding loop repeatedly check whether the decoder has output the <eos> token, at which point we break out of the loop and return what the decoder has output so far.

import random

import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, sos_idx, eos_idx, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.device = device
        
    def create_mask(self, src_len):
        max_len = src_len.max()
        idxs = torch.arange(0, max_len).to(src_len.device)
        mask = (idxs < src_len.unsqueeze(1)).float()
        return mask
        
    def forward(self, src, src_len, trg, teacher_forcing_ratio=0.5):
        
        #src = [src sent len, batch size]
        #src_len = [batch size]
        #trg = [trg sent len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
        
        if trg is None:
            inference = True
            assert teacher_forcing_ratio == 0, "Must be zero during inference"
            trg = torch.full((100, src.shape[1]), self.sos_idx, dtype=torch.long).to(src.device)
        else:
            inference = False
            
        batch_size = src.shape[1]
        max_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        
        #tensor to store decoder outputs
        outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)
        
        #tensor to store attention
        attentions = torch.zeros(max_len, batch_size, src.shape[0]).to(self.device)
        
        #encoder_outputs is all hidden states of the input sequence, back and forwards
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src, src_len)
                
        #first input to the decoder is the <sos> tokens
        output = trg[0,:]
        
        mask = self.create_mask(src_len)
                
        #mask = [batch size, src sent len]
                
        for t in range(1, max_len):
            output, hidden, attention = self.decoder(output, hidden, encoder_outputs, mask)
            outputs[t] = output
            attentions[t] = attention
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.max(1)[1]
            output = (trg[t] if teacher_force else top1)
            if inference and output.item() == self.eos_idx:
                return outputs[:t], attentions[:t]
            
        return outputs, attentions

It's a bit of an ugly hack and will only work if you're inferring one sentence at a time.
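For what it's worth, the single-sentence restriction could in principle be lifted by latching a per-example "finished" flag instead of calling `.item()`. A rough, self-contained sketch of that check (function and variable names are illustrative, not from the repo):

```python
import torch

def decode_step_finished(step_logits, eos_idx, finished):
    # step_logits: [batch size, vocab size] decoder output at one time step.
    # finished: bool tensor [batch size], True once a sequence has emitted <eos>.
    top1 = step_logits.argmax(dim=1)         # greedy token per example
    finished = finished | (top1 == eos_idx)  # latch: stays True once set
    return top1, finished

batch = 4
finished = torch.zeros(batch, dtype=torch.bool)
logits = torch.zeros(batch, 5)
logits[:, 0] = 1.0    # every example predicts token 0 by default
logits[1, 3] = 10.0   # example 1 predicts token 3 (our assumed eos index)
top1, finished = decode_step_finished(logits, eos_idx=3, finished=finished)
# the decoding loop would `break` once finished.all() is True
```

Sequences that finish early keep feeding their last token back in, but their stored outputs past `<eos>` can simply be ignored when detokenizing.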

