Comments (3)
Why are the target lengths from

loss = criterion(output, target[0], target[1])

used to create a mask on the output?

mask = self._sequence_mask(target[1]).unsqueeze(2)
mask_ = mask.expand_as(input)
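Here is a minimal sketch of what can go wrong when the mask length comes from the target lengths. It assumes that _sequence_mask falls back to the longest length when no max_len is given, as the practical-pytorch helper the fix credits does. sequence_mask below is a hypothetical stand-alone helper, and the shapes (a time-major output of 6 steps, batch of 2, 3 features) are made up for illustration.

```python
import torch


def sequence_mask(lengths, max_len=None):
    # Hypothetical helper mirroring the practical-pytorch sequence mask.
    # Assumption: when max_len is omitted, it defaults to the longest length.
    if max_len is None:
        max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    # (batch, max_len) -> time-major (max_len, batch), 1.0 where t < length
    return (steps.unsqueeze(0) < lengths.unsqueeze(1)).t().float()


# Time-major output: 6 steps, batch of 2, 3 features.
output = torch.randn(6, 2, 3)
lengths = torch.tensor([4, 2])  # the longest target is only 4 steps

mask_from_targets = sequence_mask(lengths).unsqueeze(2)  # shape (4, 2, 1)
try:
    mask_from_targets.expand_as(output)
    mismatch = False
except RuntimeError:
    mismatch = True  # 4 time steps cannot be expanded to 6

# Taking max_len from the output keeps the mask aligned with it.
mask_from_output = sequence_mask(lengths, output.size(0)).unsqueeze(2)
print(mismatch, tuple(mask_from_output.expand_as(output).shape))  # True (6, 2, 3)
```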
from loop.
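Fixed this by passing the output's time dimension (input.size(0)) to _sequence_mask as max_len, so the mask always matches the output: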
import torch
import torch.nn as nn


class MaskedMSE(nn.Module):
    def __init__(self):
        super(MaskedMSE, self).__init__()
        # Sum the squared errors; normalization happens in forward().
        # reduction="sum" replaces the deprecated size_average=False.
        self.criterion = nn.MSELoss(reduction="sum")

    # Taken from
    # https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation
    @staticmethod
    def _sequence_mask(sequence_length, max_len):
        # sequence_length: (batch,) tensor of valid lengths
        batch_size = sequence_length.size(0)
        # Build the range on the same device as the lengths (no Variable/.cuda() needed)
        seq_range = torch.arange(max_len, device=sequence_length.device)
        seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
        seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
        # (batch, max_len) -> time-major (max_len, batch), 1.0 where valid
        return (seq_range_expand < seq_length_expand).t().float()

    def forward(self, input, target, lengths):
        # input/target: time-major (max_len, batch, dim)
        # Take max_len from the output itself, not from the target lengths.
        max_len = input.size(0)
        mask = self._sequence_mask(lengths, max_len).unsqueeze(2)
        mask_ = mask.expand_as(input)
        loss = self.criterion(input * mask_, target * mask_)
        # Normalize by the number of valid (time, batch) positions.
        return loss / mask.sum()
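A quick way to check that the fixed loss ignores padding is to put large errors only in padded steps. The sketch below is a stand-alone function (masked_mse is a hypothetical name) that performs the same computation as MaskedMSE.forward but builds the time-major mask directly with broadcasting instead of expand plus .t(); the toy tensors are made up.

```python
import torch
import torch.nn.functional as F


def masked_mse(output, target, lengths):
    # Same idea as MaskedMSE.forward; the (max_len, batch) mask is built
    # directly in time-major order with broadcasting.
    max_len = output.size(0)  # output/target: (max_len, batch, dim)
    steps = torch.arange(max_len, device=lengths.device)
    mask = (steps.unsqueeze(1) < lengths.unsqueeze(0)).float()  # (max_len, batch)
    mask = mask.unsqueeze(2)  # (max_len, batch, 1), broadcasts over dim
    loss = F.mse_loss(output * mask, target * mask, reduction="sum")
    # Divide by the number of valid (time, batch) positions, as MaskedMSE does.
    return loss / mask.sum()


target = torch.zeros(5, 2, 3)
output = torch.zeros(5, 2, 3)
lengths = torch.tensor([5, 3])

output[3:, 1] = 100.0  # garbage only in the padded steps of sequence 1
first = masked_mse(output, target, lengths).item()

output[0, 0] = 1.0  # error of 1.0 in all 3 features at one valid step
second = masked_mse(output, target, lengths).item()

print(first, second)  # 0.0 0.375
```

The second value is 3 / 8: three squared errors of 1.0 divided by the 5 + 3 valid positions. Because the divisor counts positions rather than elements, the result is a per-frame sum over features; dividing by mask.sum() * output.size(2) instead would give a per-element mean.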
Thanks!
Fixed in master.