
salvacarrion commented on June 8, 2024

I have written this piece of code. It is not a clean solution, but it works.

import random
from torch.utils.data.sampler import BatchSampler, RandomSampler, SequentialSampler, SubsetRandomSampler
from torchnlp.utils import identity


class MaxTokensBatchSampler(BatchSampler):

    def __init__(self,
                 sampler,
                 batch_size,
                 max_tokens,
                 drop_last,
                 sort_key=identity,
                 bucket_size_multiplier=100,
                 shuffle=True):
        super().__init__(sampler, batch_size, drop_last)
        self.max_tokens = max_tokens
        self.sort_key = sort_key
        self.bucket_size_multiplier = bucket_size_multiplier
        self.shuffle = shuffle

        # Not a clean solution: buckets are built once here, so bucket
        # membership stays fixed across epochs
        self.bucket_batches = []
        self._build_buckets()

    def __iter__(self):
        # Iterate over buckets
        for batches, batch_sizes in self.bucket_batches:
            # Shuffle bucket-batch order
            batches = SubsetRandomSampler(batches) if self.shuffle else batches
            for batch in batches:
                if self.shuffle:  # Shuffle inner batch
                    random.shuffle(batch)
                yield batch  # Batch indexes [sent1_idx, sent2_idx,...]

    def __len__(self):
        # Total number of batches across all buckets
        return sum(len(batches) for batches, _ in self.bucket_batches)

    def _build_buckets(self):
        # Randomize samples
        tmp_sampler = RandomSampler(self.sampler) if self.shuffle else self.sampler

        # Split samples in N batches (or "buckets")
        tmp_sampler = BatchSampler(tmp_sampler, min(self.batch_size * self.bucket_size_multiplier, len(self.sampler)),
                                   False)

        # Sort samples
        self.bucket_batches = []
        for bucket in tmp_sampler:
            bucket_sorted = sorted([(i, self.sort_key(i)) for i in bucket], key=lambda x: x[1])

            # Create batches constrained by max_tokens
            batches = []
            batch_sizes = []

            last_batch = []
            last_batch_size = 0
            for sample_i, length_i in bucket_sorted:
                if (last_batch_size + length_i) < self.max_tokens:
                    last_batch.append(sample_i)
                    last_batch_size += length_i
                else:
                    # Close the current batch, skipping it if it is empty
                    # (happens when the first sample alone reaches max_tokens)
                    if last_batch:
                        batches.append(last_batch)
                        batch_sizes.append(last_batch_size)

                    # Start a new batch with the current sample
                    last_batch = [sample_i]
                    last_batch_size = length_i

            # Add the last batch, if any
            if last_batch:
                batches.append(last_batch)
                batch_sizes.append(last_batch_size)

            # Add bucket batches
            self.bucket_batches.append((batches, batch_sizes))

It works as follows:

  1. Randomize all samples/sentences: [0,1,2,3,...n] => [6, 12, 60,... , 31]
  2. Split samples into buckets => [6, 12, 60,...], [92, 1, 52,... , 24], [95, 234, 33,... , 31]
  3. Sort in-bucket by sentence lengths
  4. Split each sorted bucket into batches whose total length stays below max_tokens
  5. Shuffle the batch order within each bucket: bucket1 (batch1, batch2,... batchN => batch5, batch10,... batch3), bucket2...
  6. Shuffle the sentences within each batch: batch1: [23, 51, 12...] => [391, 2, 33,...]
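
The steps above can be sketched without torch as a single function. This is only an illustration under stated assumptions, not the sampler itself: the name max_token_batches, its parameters, and the toy lengths are hypothetical, and it returns one flat list of batches instead of keeping them grouped per bucket.

```python
import random


def max_token_batches(lengths, max_tokens, bucket_size, shuffle=True, seed=0):
    """Group sample indices into batches whose summed length stays below max_tokens.

    Hypothetical helper: `lengths[i]` is the token count of sample i.
    """
    rng = random.Random(seed)

    # Randomize the sample order
    indices = list(range(len(lengths)))
    if shuffle:
        rng.shuffle(indices)

    all_batches = []
    # Split the samples into buckets of `bucket_size` indices
    for start in range(0, len(indices), bucket_size):
        # Sort the bucket by sample length
        bucket = sorted(indices[start:start + bucket_size], key=lambda i: lengths[i])

        # Pack the sorted samples into batches under the token budget
        batches, current, current_tokens = [], [], 0
        for i in bucket:
            if current and current_tokens + lengths[i] >= max_tokens:
                batches.append(current)
                current, current_tokens = [], 0
            current.append(i)
            current_tokens += lengths[i]
        if current:
            batches.append(current)

        if shuffle:
            rng.shuffle(batches)    # shuffle batch order inside the bucket
            for batch in batches:
                rng.shuffle(batch)  # shuffle samples inside each batch
        all_batches.extend(batches)
    return all_batches


toy_lengths = [5, 1, 3, 8, 2, 7, 4, 6]
print(max_token_batches(toy_lengths, max_tokens=10, bucket_size=4, shuffle=False))
# → [[1, 2, 0], [3], [4, 6], [7], [5]]
```

A sample whose length alone reaches max_tokens still gets a batch of its own here, so no sample is dropped.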

You can call it using:

train_sampler = MaxTokensBatchSampler(
    SequentialSampler(train_ds), shuffle=True, batch_size=BATCH_SIZE,
    max_tokens=MAX_TOKENS, drop_last=False,
    sort_key=lambda i: len(train_ds.datasets.iloc[i]["src"].split()))
val_sampler = MaxTokensBatchSampler(
    SequentialSampler(val_ds), shuffle=False, batch_size=BATCH_SIZE,
    max_tokens=MAX_TOKENS, drop_last=False,
    sort_key=lambda i: len(val_ds.datasets.iloc[i]["src"].split()))
 

train_ds and val_ds are instances of a torch Dataset subclass (class TranslationDataset(Dataset):).
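
The sort_key used in the call looks up a row and splits the source sentence each time it is called. A possible alternative, sketched here assuming the same whitespace tokenization, is to count the tokens once and have sort_key read from that list. The src_sentences list is a hypothetical stand-in for the "src" column of the dataset.

```python
# Hypothetical stand-in for the "src" column of the dataset
src_sentences = ["a short one", "a somewhat longer source sentence", "hi"]

# Count whitespace-separated tokens once, up front
src_lengths = [len(sentence.split()) for sentence in src_sentences]


def sort_key(i):
    """Return the precomputed token count of sample i."""
    return src_lengths[i]


print([sort_key(i) for i in range(len(src_sentences))])
# → [3, 5, 1]
```

It would be passed as sort_key=sort_key. This mostly pays off if the sampler is rebuilt repeatedly, for example once per epoch to get fresh buckets.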

from pytorch-nlp.
