Comments (5)
If by "fix" you mean allow it to be changed, then you can add an argument to the Seq2Seq class that takes an integer and sets the trg tensor length to that integer.
from pytorch-seq2seq.
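For illustration, here is a minimal sketch of that idea. The function and parameter names (`make_inference_trg`, `max_trg_len`) are assumptions for the example, not the tutorial's actual API: the maximum target length is passed in as an argument and the placeholder trg tensor is built from it instead of hard-coding the length.

```python
import torch

def make_inference_trg(batch_size, sos_idx, max_trg_len=20, device="cpu"):
    # Placeholder target tensor filled with <sos>. Its first row seeds the
    # decoder; its length fixes the number of decoding steps at inference.
    return torch.full((max_trg_len, batch_size), sos_idx,
                      dtype=torch.long, device=device)
```

Calling `make_inference_trg(batch_size, sos_idx, max_trg_len=50)` then gives you a 50-step decoding budget without touching the model code.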
No, I mean: why did you fix it at 20 in this case? Sorry for the unclear question.
That 20 is the maximum number of tokens in the inferred translation. If you expect your translated phrases to be longer than 20 tokens, then it should be increased.
Yeah, that's what I mean. We can't know before the inference step what the maximum number of tokens in the inferred translation will be. I wonder if we could set it to something like 300 to be safe, but that would hurt inference time as well. It would be better if we could find another way to do it.
One possible way to do it is to set an arbitrary maximum length, but within the decoding loop repeatedly check whether the decoder has output the <eos> token, at which point we break out of the loop and return what the decoder has output so far.
import random

import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, sos_idx, eos_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.device = device

    def create_mask(self, src_len):
        max_len = src_len.max()
        idxs = torch.arange(0, max_len).to(src_len.device)
        mask = (idxs < src_len.unsqueeze(1)).float()
        return mask

    def forward(self, src, src_len, trg, teacher_forcing_ratio=0.5):
        #src = [src sent len, batch size]
        #src_len = [batch size]
        #trg = [trg sent len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
        if trg is None:
            inference = True
            assert teacher_forcing_ratio == 0, "Must be zero during inference"
            #arbitrary maximum length of 100; the loop below exits early on <eos>
            trg = torch.zeros((100, src.shape[1]), dtype=torch.long).fill_(self.sos_idx).to(src.device)
        else:
            inference = False

        batch_size = src.shape[1]
        max_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim

        #tensor to store decoder outputs
        outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)
        #tensor to store attention
        attentions = torch.zeros(max_len, batch_size, src.shape[0]).to(self.device)

        #encoder_outputs is all hidden states of the input sequence, back and forwards
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src, src_len)

        #first input to the decoder is the <sos> tokens
        output = trg[0, :]

        mask = self.create_mask(src_len)
        #mask = [batch size, src sent len]

        for t in range(1, max_len):
            output, hidden, attention = self.decoder(output, hidden, encoder_outputs, mask)
            outputs[t] = output
            attentions[t] = attention
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.max(1)[1]
            output = (trg[t] if teacher_force else top1)
            #during inference, stop as soon as the decoder emits <eos>
            if inference and output.item() == self.eos_idx:
                return outputs[:t], attentions[:t]

        return outputs, attentions
It's a bit of an ugly hack and will only work if you're inferring one sentence at a time.
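If you do want batched inference, one common workaround is to track a per-sequence "finished" flag and stop only once every sequence in the batch has emitted <eos>. The sketch below is a simplified, hedged illustration of that loop structure, not the tutorial's code: `step_fn` stands in for one decoder step, and `eos_idx` matches the attribute used above.

```python
import torch

def decode_with_early_stop(step_fn, first_input, max_len, eos_idx):
    """step_fn(prev_tokens) -> next_tokens; a stand-in for one decoder step."""
    batch_size = first_input.shape[0]
    finished = torch.zeros(batch_size, dtype=torch.bool)
    output, tokens = first_input, []
    for _ in range(max_len):
        output = step_fn(output)
        # Freeze sequences that already emitted <eos> by re-emitting <eos>.
        output = torch.where(finished, torch.full_like(output, eos_idx), output)
        tokens.append(output)
        finished |= output == eos_idx
        # Stop once every sequence in the batch has finished.
        if finished.all():
            break
    return torch.stack(tokens)
```

The point of the `finished` mask is that short sequences just keep emitting <eos> padding while longer ones continue, so the loop still ends as early as the longest sequence allows rather than always running to `max_len`.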