I have written this piece of code. It is not a clean solution, but it works.
```python
import random

from torch.utils.data.sampler import BatchSampler, RandomSampler, SubsetRandomSampler
from torchnlp.utils import identity


class MaxTokensBatchSampler(BatchSampler):

    def __init__(self,
                 sampler,
                 batch_size,
                 max_tokens,
                 drop_last,
                 sort_key=identity,
                 bucket_size_multiplier=100,
                 shuffle=True):
        super().__init__(sampler, batch_size, drop_last)
        self.max_tokens = max_tokens
        self.sort_key = sort_key
        self.bucket_size_multiplier = bucket_size_multiplier
        self.shuffle = shuffle

        # Not a clean solution: batches are precomputed eagerly
        self.bucket_batches = []
        self._build_buckets()

    def __iter__(self):
        # Iterate over buckets
        for batches, batch_sizes in self.bucket_batches:
            # Shuffle the batch order within the bucket
            batches = SubsetRandomSampler(batches) if self.shuffle else batches
            for batch in batches:
                if self.shuffle:  # Shuffle the samples inside the batch
                    random.shuffle(batch)
                yield batch  # Batch of sample indexes: [sent1_idx, sent2_idx, ...]

    def __len__(self):
        # Total number of batches across all buckets
        return sum(len(batches) for batches, _ in self.bucket_batches)

    def _build_buckets(self):
        # Randomize samples
        tmp_sampler = RandomSampler(self.sampler) if self.shuffle else self.sampler

        # Split samples into N large batches (or "buckets")
        tmp_sampler = BatchSampler(
            tmp_sampler,
            min(self.batch_size * self.bucket_size_multiplier, len(self.sampler)),
            False)

        # Sort each bucket by length, then pack token-constrained batches
        self.bucket_batches = []
        for bucket in tmp_sampler:
            bucket_sorted = sorted([(i, self.sort_key(i)) for i in bucket], key=lambda x: x[1])

            batches = []
            batch_sizes = []
            last_batch = []
            last_batch_size = 0
            for sample_i, length_i in bucket_sorted:
                if last_batch_size + length_i < self.max_tokens:
                    last_batch.append(sample_i)
                    last_batch_size += length_i
                else:
                    # Close the current batch and start a new one.
                    # Guard: without this check, a single sample longer than
                    # max_tokens would emit an empty batch.
                    if last_batch:
                        batches.append(last_batch)
                        batch_sizes.append(last_batch_size)
                    last_batch = [sample_i]
                    last_batch_size = length_i

            # Add the last batch
            if last_batch:
                batches.append(last_batch)
                batch_sizes.append(last_batch_size)

            self.bucket_batches.append((batches, batch_sizes))
```
It works as follows:
- Randomize all samples/sentences: [0, 1, 2, 3, ..., n] => [6, 12, 60, ..., 31]
- Split the samples into buckets => [6, 12, 60, ...], [92, 1, 52, ..., 24], [95, 234, 33, ..., 31]
- Sort each bucket by sentence length
- Shuffle the batch order within each bucket: bucket1 (batch1, batch2, ..., batchN => batch5, batch10, ..., batch3), bucket2, ...
- Shuffle the sentences within each batch: batch1: [23, 51, 12, ...] => [391, 2, 33, ...]
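The packing step above (sort by length, then fill each batch until the token budget would be exceeded) can be sketched in plain Python. `pack_by_tokens` is a hypothetical helper, not part of the class; it uses the same strict `<` budget check:

```python
def pack_by_tokens(lengths, max_tokens):
    """Greedily pack sample indexes into batches whose summed length stays
    under max_tokens. Mirrors the inner loop of _build_buckets."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, batch, total = [], [], 0
    for i in order:
        if total + lengths[i] < max_tokens:  # same strict check as the class
            batch.append(i)
            total += lengths[i]
        else:
            if batch:  # skip empty batches caused by over-long samples
                batches.append(batch)
            batch, total = [i], lengths[i]
    if batch:
        batches.append(batch)
    return batches

# Lengths 3, 1, 2, 5 with a budget of 6 tokens:
print(pack_by_tokens([3, 1, 2, 5], 6))  # → [[1, 2], [0], [3]]
```

Note that batches are formed from length-sorted indexes, so sentences of similar length end up together and padding waste stays low.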
You can call it like this:
```python
train_sampler = MaxTokensBatchSampler(
    SequentialSampler(train_ds), shuffle=True, batch_size=BATCH_SIZE,
    max_tokens=MAX_TOKENS, drop_last=False,
    sort_key=lambda i: len(train_ds.datasets.iloc[i]["src"].split()))
val_sampler = MaxTokensBatchSampler(
    SequentialSampler(val_ds), shuffle=False, batch_size=BATCH_SIZE,
    max_tokens=MAX_TOKENS, drop_last=False,
    sort_key=lambda i: len(val_ds.datasets.iloc[i]["src"].split()))
```
`train_ds` and `val_ds` are torch `Dataset` instances (`class TranslationDataset(Dataset):`).
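To consume such a sampler, pass it to a `DataLoader` via the `batch_sampler` argument. A minimal sketch, using torch's stock `BatchSampler` as a stand-in (the wiring is identical for the custom sampler); `ToySet` is a hypothetical dataset:

```python
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class ToySet(Dataset):
    """Hypothetical dataset of raw sentences."""

    def __init__(self, sents):
        self.sents = sents

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, i):
        return self.sents[i]


ds = ToySet(["a b", "c", "d e f"])
# A batch sampler yields lists of indexes, so DataLoader must NOT also be
# given batch_size, shuffle, or drop_last.
sampler = BatchSampler(SequentialSampler(ds), batch_size=2, drop_last=False)
loader = DataLoader(ds, batch_sampler=sampler, collate_fn=list)
batches = list(loader)  # → [["a b", "c"], ["d e f"]]
```

Because `__iter__` of the sampler yields whole index lists, each `DataLoader` step is one token-budgeted batch.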
from pytorch-nlp.