Didactic meme

A modelling suite with extra focus on PyTorch

  • Speed up the modelling process
  • Increase traceability of trained models
  • Make model comparison and high scores easier
  • Visualize model predictions
  • Easily expose models via a web API
  • Hash train, validate, and test datasets separately
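Hashing each split separately makes it possible to tell later exactly which data a model was trained, validated, and tested on. Below is a minimal sketch that uses only the standard library. It assumes each example has a stable `repr`. The function name `hash_dataset` is hypothetical and not part of the library:

```python
import hashlib
from typing import Iterable


def hash_dataset(examples: Iterable) -> str:
    """Return a SHA-256 hex digest over the examples, in iteration order."""
    digest = hashlib.sha256()
    for example in examples:
        # Assumption: repr() gives a stable, deterministic encoding of an example.
        encoded = repr(example).encode("utf-8")
        # Length-prefix each example so example boundaries affect the hash.
        digest.update(len(encoded).to_bytes(8, "big"))
        digest.update(encoded)
    return digest.hexdigest()


splits = {
    "train": [(0.1, 1), (0.2, 0)],
    "validate": [(0.3, 1)],
    "test": [(0.4, 0)],
}
# One hash per split, to be stored alongside the trained model.
split_hashes = {name: hash_dataset(examples) for name, examples in splits.items()}
```

The hash depends on example order. If a loader shuffles, hash the underlying dataset rather than the loader.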

Planned features:

  • Model config
  • Standard training loop
  • Set up loggers
  • Command-line config and training
  • Visualize black-box solutions
  • Web API helper
  • Training helper functions
  • Intra-epoch logging
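For command-line config and training, one possible approach is to have `argparse` produce the config dictionary that is passed to the trainer. This is a minimal sketch. The option names `--batch-size`, `--learning-rate`, and `--max-epochs` are hypothetical:

```python
import argparse


def parse_config(argv=None) -> dict:
    """Parse training options from the command line into a plain config dict."""
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--max-epochs", type=int, default=10)
    # vars() turns the Namespace into a dict; dashes in option names become underscores.
    return vars(parser.parse_args(argv))


config = parse_config(["--batch-size", "64", "--max-epochs", "5"])
```

When `argv` is `None`, `argparse` reads `sys.argv`, so the same function serves both scripts and tests.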

TODO

  • Reconsider how the API and helpers work
  • Save models by epoch
  • Score models and list high scores
  • tensorboardX support
  • Handle custom preprocessing
  • Continue training/tuning from a checkpoint (hash the initial model)
  • Tabular logging like skorch
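Scoring models and listing high scores could be as simple as ranking runs by their best validation score. The sketch below is hypothetical. It assumes higher is better (e.g. validation log-probability) and that each run is identified by a string, such as a hash of its config and datasets:

```python
from typing import Dict, List, Tuple


def highscores(scores: Dict[str, List[float]], top: int = 3) -> List[Tuple[str, float]]:
    """Return the `top` runs ranked by their best score, highest first."""
    best = {run: max(run_scores) for run, run_scores in scores.items() if run_scores}
    return sorted(best.items(), key=lambda item: item[1], reverse=True)[:top]


# Validation log-probability per epoch for three hypothetical runs.
scores = {
    "run-a": [-1.2, -0.9, -0.95],
    "run-b": [-1.0, -0.8],
    "run-c": [-2.0, -1.5, -1.1],
}
for rank, (run, score) in enumerate(highscores(scores), start=1):
    print(f"{rank}. {run}: {score}")
```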

Usage

A draft of how the library is intended to be used.

Overly simplistic training

No custom code or metrics

import model_suite


def train(train_ds, validate_ds, config):
    # create train_loader, validate_loader, model, and optimizer
    model_suite.Trainer(train_loader, validate_loader, model, optimizer, config).train()
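The trainer is meant to own the loop, so that subclasses only override hooks such as `train_batch`, `evaluate_batch`, and `summarize_epoch`. Below is a rough sketch of what `Trainer.train()` might do with those hooks. It is written without PyTorch so the control flow is easy to follow. The class name `SketchTrainer`, the `n_epochs` config key, and the default hooks are assumptions, not the library's API:

```python
class SketchTrainer:
    """Hypothetical, framework-free outline of a standard training loop."""

    def __init__(self, train_loader, validate_loader, model, optimizer, config):
        self.train_loader = train_loader
        self.validate_loader = validate_loader
        self.model = model
        self.optimizer = optimizer
        self.config = config

    def train_batch(self, features, labels):
        # Override: forward pass, backward pass, optimizer step; return batch metrics.
        return {}

    def evaluate_batch(self, features, labels):
        # Override: forward pass only; return batch metrics.
        return {}

    def summarize_epoch(self, train_results, validate_results, epoch):
        # Default: just count batches so the loop's output is visible.
        return {"n_train_batches": len(train_results), "n_validate_batches": len(validate_results)}

    def train(self):
        history = []
        for epoch in range(self.config["n_epochs"]):
            # With PyTorch this would call model.train() here, then model.eval()
            # and wrap the validation pass in torch.no_grad().
            train_results = [self.train_batch(f, l) for f, l in self.train_loader]
            validate_results = [self.evaluate_batch(f, l) for f, l in self.validate_loader]
            history.append(self.summarize_epoch(train_results, validate_results, epoch))
        return history


history = SketchTrainer(
    train_loader=[(1, 0), (2, 1)],
    validate_loader=[(3, 1)],
    model=None,
    optimizer=None,
    config={"n_epochs": 2},
).train()
```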

Standard training loop

Custom metric

import model_suite


class Trainer(model_suite.Trainer):
    def train_batch(self, features, labels):
        log_prob = self.model(features).log_prob(labels)
        loss = -log_prob.mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Return per-batch sums as Python floats (summarize_epoch divides by the
        # dataset size); get_accuracy is a user-defined per-example helper
        return dict(
            log_prob=log_prob.sum().item(),
            accuracy=get_accuracy(features, labels, self.model).sum().item(),
        )

    def evaluate_batch(self, features, labels):
        return dict(
            log_prob=self.model(features).log_prob(labels).sum().item(),
            accuracy=get_accuracy(features, labels, self.model).sum().item(),
        )

    def summarize_epoch(self, train_results, validate_results, epoch):
        train_results, validate_results = self.zip_dicts(train_results), self.zip_dicts(validate_results)

        # Turn summed batch metrics into per-example averages
        train_results = {key: sum(value)/len(self.train_loader.dataset) for key, value in train_results.items()}
        validate_results = {key: sum(value)/len(self.validate_loader.dataset) for key, value in validate_results.items()}

        epoch_results = self.merge_add_level(train=train_results, validate=validate_results)
        self.add_tb_scalars(epoch_results, epoch)

        return {key: epoch_results[key] for key in ['log_prob']}

def train(train_ds, validate_ds, config):
    # create train_loader, validate_loader, model, and optimizer
    Trainer(train_loader, validate_loader, model, optimizer, config).train()
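`summarize_epoch` relies on two helpers, `zip_dicts` and `merge_add_level`, whose behaviour is not specified here. The plain functions below sketch one plausible reading, assuming each batch returns summed metrics:

- `zip_dicts` turns a list of per-batch dicts into a dict of lists.
- `merge_add_level` adds a `train`/`validate` level under each metric.

Both implementations are assumptions:

```python
from typing import Any, Dict, List


def zip_dicts(dicts: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """[{'a': 1}, {'a': 2}] -> {'a': [1, 2]}; assumes every dict has the same keys."""
    return {key: [d[key] for d in dicts] for key in dicts[0]} if dicts else {}


def merge_add_level(**groups: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """merge_add_level(train={'a': 1}, validate={'a': 2}) -> {'a': {'train': 1, 'validate': 2}}."""
    merged: Dict[str, Dict[str, Any]] = {}
    for group_name, results in groups.items():
        for key, value in results.items():
            merged.setdefault(key, {})[group_name] = value
    return merged


# Summed per-batch metrics; 8 training and 4 validation examples in total.
train_batches = [{"log_prob": -4.0, "accuracy": 3.0}, {"log_prob": -2.0, "accuracy": 1.0}]
validate_batches = [{"log_prob": -2.0, "accuracy": 3.0}]
train = {key: sum(value) / 8 for key, value in zip_dicts(train_batches).items()}
validate = {key: sum(value) / 4 for key, value in zip_dicts(validate_batches).items()}
epoch_results = merge_add_level(train=train, validate=validate)
```

With this shape, `epoch_results['log_prob']` is a `{'train': ..., 'validate': ...}` dict. That is the form `SummaryWriter.add_scalars(main_tag, tag_scalar_dict, global_step)` takes to plot both curves on one chart.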

Generative Adversarial Model

import model_suite


class Trainer(model_suite.Trainer):
    def train_batch(self, features, boards):
        # Update the discriminator first
        discriminator_log_prob = self.model.get_discriminator_log_prob(features, boards)
        discriminator_loss = -discriminator_log_prob.mean()
        self.optimizer.discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        self.optimizer.discriminator_optimizer.step()

        # Compute the generator objective only after the discriminator step: if it
        # shares graph or weights with the discriminator, backward() would otherwise
        # hit a freed graph or parameters modified in place by step()
        generator_log_prob = self.model.get_generator_log_prob(features)
        generator_loss = -generator_log_prob.mean()
        self.optimizer.generator_optimizer.zero_grad()
        generator_loss.backward()
        self.optimizer.generator_optimizer.step()

        # .item() returns floats so the epoch's results do not keep autograd graphs alive
        return dict(
            discriminator_log_prob=discriminator_log_prob.sum().item(),
            generator_log_prob=generator_log_prob.sum().item(),
        )

    def evaluate_batch(self, features, boards):
        return dict(
            discriminator_log_prob=self.model.get_discriminator_log_prob(features, boards).sum().item(),
            generator_log_prob=self.model.get_generator_log_prob(features).sum().item(),
        )


def train(train_ds, validate_ds, config):
    # set up data loaders, model, and optimizers

    optimizer = model_suite.MultipleOptimizers(
        generator_optimizer=...,
        discriminator_optimizer=...,
    )

    Trainer(train_loader, validate_loader, model, optimizer, config).train()
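`model_suite.MultipleOptimizers` groups several optimizers under one object, and `train_batch` reaches each one by name. Below is a hypothetical sketch of such a container. Two behaviours are assumptions, not something the draft confirms: forwarding `zero_grad` and `step` to every member, and per-optimizer `state_dict` for checkpointing. In real use each value would be a `torch.optim` optimizer; a counting stub stands in here so the sketch runs without PyTorch:

```python
class MultipleOptimizers:
    """Hypothetical container: named optimizers as attributes, with forwarding helpers."""

    def __init__(self, **optimizers):
        self.optimizers = dict(optimizers)
        for name, optimizer in optimizers.items():
            # Expose each optimizer as an attribute, e.g. self.generator_optimizer
            setattr(self, name, optimizer)

    def zero_grad(self):
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self):
        for optimizer in self.optimizers.values():
            optimizer.step()

    def state_dict(self):
        # One state dict per optimizer, keyed by name, so a checkpoint can restore each
        return {name: optimizer.state_dict() for name, optimizer in self.optimizers.items()}

    def load_state_dict(self, state_dict):
        for name, optimizer in self.optimizers.items():
            optimizer.load_state_dict(state_dict[name])


class CountingOptimizer:
    """Stand-in for a torch.optim optimizer that only counts calls."""

    def __init__(self):
        self.steps = 0

    def zero_grad(self):
        pass

    def step(self):
        self.steps += 1

    def state_dict(self):
        return {"steps": self.steps}

    def load_state_dict(self, state):
        self.steps = state["steps"]


optimizer = MultipleOptimizers(
    generator_optimizer=CountingOptimizer(),
    discriminator_optimizer=CountingOptimizer(),
)
optimizer.discriminator_optimizer.step()  # step one optimizer by name
optimizer.step()                          # step all of them
```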

Contributors

  • samedii
