Comments (3)

SunQpark commented on August 25, 2024

Here is a quick implementation of iteration-based training, which replaces lines 51 to 86 of trainer.py.

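        # Requires `import itertools` and `import numpy as np` at the top of trainer.py.
        batch_per_epoch = 10000  # iterations per pseudo-epoch (logging/validation interval)
        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))
        for batch_idx, (data, target) in enumerate(itertools.cycle(self.data_loader)):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss(output, target)
            loss.backward()
            self.optimizer.step()

            # accumulate training statistics, as the epoch-based loop does
            total_loss += loss.item()
            total_metrics += self._eval_metrics(output, target)

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(...))

            # end of the pseudo-epoch: average over the batches actually seen
            # (the old check `batch_idx % batch_per_epoch == 0` fired at batch 0)
            if batch_idx + 1 == batch_per_epoch:
                log = {
                    'loss': total_loss / batch_per_epoch,
                    'metrics': (total_metrics / batch_per_epoch).tolist()
                }
                # validate only if enabled, but always return so the loop ends
                if self.do_validation:
                    val_log = self._valid_epoch(epoch)
                    log = {**log, **val_log}

                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                return log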

As you may have noticed, the structure of the training loop differs considerably from that of epoch-based training. I'm not currently working on iteration-based training, because I could not find a clean and simple enough way to support both iteration-based and epoch-based training.

Moreover, since a PyTorch dataset and data loader are essentially finite sequences, we need to make the data loader loop infinitely. (I used itertools.cycle for that here, but this is not correct: cycle saves the batches from the first pass and replays them, so the order of sampled batches is fixed after the first pass.) As far as I know, the best way to do this is to implement a custom data loader as something like a generator object, which I think is too much for this simple MNIST example.
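A minimal sketch of the generator approach, in plain Python so it runs without PyTorch. The names `infinite_batches`, `EpochSampler`, `ShuffledLoader` and the `len_epoch` parameter are hypothetical, not part of pytorch-template. Unlike `itertools.cycle`, the generator calls `iter()` on the loader again after every full pass; for a `torch.utils.data.DataLoader` created with `shuffle=True`, each new iterator draws a fresh permutation. `len_epoch=None` keeps ordinary epoch-based behaviour, so one training loop could in principle serve both modes.

```python
import itertools
import random
from typing import Iterable, Iterator, Optional


def infinite_batches(loader: Iterable) -> Iterator:
    """Yield batches forever, starting a fresh pass over `loader` each time.

    Assumes the loader yields at least one batch; otherwise this never returns.
    """
    while True:
        # `yield from` calls iter(loader) again, so a shuffling loader
        # reshuffles on every pass instead of replaying the first one.
        yield from loader


class EpochSampler:
    """Hypothetical helper that serves one "epoch" per iter() call.

    len_epoch=None -> epoch-based: one full pass over the loader.
    len_epoch=N    -> iteration-based: N batches from a single infinite stream,
                      so each epoch continues where the previous one stopped.
    """

    def __init__(self, loader: Iterable, len_epoch: Optional[int] = None):
        self.loader = loader
        self.len_epoch = len_epoch
        self._stream = infinite_batches(loader) if len_epoch is not None else None

    def __iter__(self) -> Iterator:
        if self._stream is None:
            return iter(self.loader)
        return itertools.islice(self._stream, self.len_epoch)


class ShuffledLoader:
    """Toy stand-in for a DataLoader with shuffle=True."""

    def __init__(self, items, seed: int = 0):
        self.items = list(items)
        self.rng = random.Random(seed)
        self.passes = 0  # how many times a new pass has been started

    def __iter__(self):
        self.passes += 1
        order = self.items[:]
        self.rng.shuffle(order)
        return iter(order)


# itertools.cycle starts only one pass and replays it: the order never changes.
cyc = ShuffledLoader(range(5))
replayed = list(itertools.islice(itertools.cycle(cyc), 10))
print(cyc.passes, replayed[:5] == replayed[5:])  # 1 True

# The generator starts a new (reshuffled) pass after each full pass.
gen = ShuffledLoader(range(5))
fresh = list(itertools.islice(infinite_batches(gen), 10))
print(gen.passes)  # 2
```

Keeping a single `_stream` per sampler is the design choice that matters here: if a new infinite generator were created for every epoch, any batches left over from a partial pass would be silently dropped.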

from pytorch-template.

ag14774 commented on August 25, 2024

Here is how I implemented it. I added a step(self, epoch) method to BaseDataLoader. step() can be used to store some internal state in the data loader object so that it does something different in each epoch. A simple example would be a dataset that implements __getitem__(self, key) as follows:

import numpy as np

def __getitem__(self, key):
    if key >= len(self):
        raise IndexError(key)
    # seed from the sample index plus the per-epoch offset set by step()
    np.random.seed(key + self.seed_offset)
    return np.random.uniform(0, 100)

So let's say we have 5000 samples. dataset[i] seeds the generator with i (plus seed_offset) and outputs a random number. I defined step in the data loader as:

def step(self, epoch):
    self.dataset.seed_offset = epoch*len(self.dataset)

Then, before the training loop in trainer.py, we call self.data_loader.step(epoch). Because step() shifts the seeds by len(dataset) every epoch, epoch e uses seeds e*5000 through e*5000 + 4999, so no two epochs produce the same samples. In the config file you set epochs to something very high, and that's it. This is of course just an example; step() can be used to do other things in the data loader. So just declare step() as an abstract method and leave it up to whoever implements the DataLoader to override it.
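A framework-free sketch of the scheme above, assuming a minimal `BaseDataLoader` that only wraps a dataset and iterates over it (unlike a real PyTorch-based loader). `SeededDataset` and `SeededDataLoader` are hypothetical names, and a per-item `random.Random` replaces `np.random.seed` so the global RNG state is left untouched. Declaring `step` with `abc.abstractmethod` means a subclass that forgets to override it cannot be instantiated.

```python
import random
from abc import ABC, abstractmethod


class SeededDataset:
    """Each item is a random number seeded by its index plus an epoch offset."""

    def __init__(self, size: int):
        self.size = size
        self.seed_offset = 0  # set by the data loader's step()

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, key: int) -> float:
        if not 0 <= key < len(self):
            raise IndexError(key)
        # A per-item random.Random instead of np.random.seed: same idea,
        # but it leaves the global NumPy/`random` state untouched.
        rng = random.Random(key + self.seed_offset)
        return rng.uniform(0, 100)


class BaseDataLoader(ABC):
    """Hypothetical, framework-free stand-in for BaseDataLoader."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return (self.dataset[i] for i in range(len(self.dataset)))

    @abstractmethod
    def step(self, epoch: int) -> None:
        """Update internal state before each epoch; subclasses must override."""


class SeededDataLoader(BaseDataLoader):
    def step(self, epoch: int) -> None:
        # Epoch e uses seeds e*N .. e*N + N - 1, so no two epochs share a seed.
        self.dataset.seed_offset = epoch * len(self.dataset)


loader = SeededDataLoader(SeededDataset(5000))
for epoch in range(3):  # in the real config, `epochs` would be set very high
    loader.step(epoch)  # called before each epoch's training loop
    samples = list(loader)
```

Calling `step(epoch)` with the same epoch number reproduces that epoch's samples exactly, which also makes resuming from a checkpoint deterministic under this scheme.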

SunQpark commented on August 25, 2024

I added another implementation of iteration-based training in my PR #53.
Can you check whether this PR works for your case, @ag14774?
