
pytorch-generative

pytorch-generative is a Python library which makes generative modeling in PyTorch easier by providing:

  • high-quality reference implementations of state-of-the-art (SOTA) generative models
  • useful abstractions of common building blocks found in the literature
  • utilities for training, debugging, and working with Google Colab
  • integration with TensorBoard for easy metrics visualization

To get started, see the sections below.

Installation

To install pytorch-generative, clone the repository and install the requirements:

git clone https://www.github.com/EugenHotaj/pytorch-generative
cd pytorch-generative
pip install -r requirements.txt

After installation, run the tests to sanity check that everything works:

python -m unittest discover

Reproducing Results

All our models implement a reproduce function containing the hyperparameters needed to replicate the results listed in the supported algorithms section. This makes it easy to reproduce any result with our training script, for example:

python train.py --model image_gpt --logdir /tmp/run --use-cuda

Training metrics will periodically be logged to TensorBoard for easy visualization. To view these metrics, launch a local TensorBoard server:

tensorboard --logdir /tmp/run
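
If you are training inside a notebook such as Google Colab, the same dashboard can be displayed inline using the standard TensorBoard notebook magics (notebook functionality, not specific to this library):

%load_ext tensorboard
%tensorboard --logdir /tmp/run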

To run a model on a different dataset, with different hyperparameters, etc., simply modify its reproduce function and rerun the commands above.
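
A reproduce function can also be invoked directly from Python; below is a minimal, hedged sketch, assuming the model module exposes reproduce at its top level (the actual import path and keyword arguments are defined per model and may differ):

# Hedged sketch: calling a model's reproduce function from Python instead of
# train.py. The module path and keyword argument below are assumptions; check
# the model's source for the actual signature.
from pytorch_generative.models import image_gpt  # assumed module path

image_gpt.reproduce(log_dir="/tmp/run")  # assumed keyword argument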

Google Colab

To use pytorch-generative in Google Colab, clone the repository and move the pytorch_generative package up to the top-level directory:

!git clone https://www.github.com/EugenHotaj/pytorch-generative
!mv pytorch-generative/pytorch_generative .

You can then import pytorch-generative like any other library:

import pytorch_generative as pg
from pytorch_generative import models
...

Example - ImageGPT

Supported models are implemented as PyTorch Modules and are easy to use:

from pytorch_generative import models

... # Data loading code.

# Binarized MNIST: 1 input channel, 1 output channel, 28x28 images.
model = models.ImageGPT(in_channels=1, out_channels=1, in_size=28)
model(batch)  # Per-pixel predictions with shape (N, 1, 28, 28).

Alternatively, the lower-level building blocks in pytorch_generative.nn can be used to write models from scratch. We show how to implement a convolutional ImageGPT model below:

import torch
from torch import nn

from pytorch_generative import nn as pg_nn


class TransformerBlock(nn.Module):
  """An ImageGPT Transformer block."""

  def __init__(self, n_channels, n_attention_heads):
    """Initializes a new TransformerBlock instance.
    
    Args:
      n_channels: The number of input and output channels.
      n_attention_heads: The number of attention heads to use.
    """
    super().__init__()
    self._ln1 = pg_nn.NCHWLayerNorm(n_channels)
    self._ln2 = pg_nn.NCHWLayerNorm(n_channels)
    self._attn = pg_nn.CausalAttention(
        in_channels=n_channels,
        embed_channels=n_channels,
        out_channels=n_channels,
        n_heads=n_attention_heads,
        mask_center=False)
    self._out = nn.Sequential(
        nn.Conv2d(
            in_channels=n_channels, 
            out_channels=4*n_channels, 
            kernel_size=1),
        nn.GELU(),
        nn.Conv2d(
            in_channels=4*n_channels, 
            out_channels=n_channels, 
            kernel_size=1))

  def forward(self, x):
    # Pre-norm residual attention followed by a pointwise feed-forward block.
    x = x + self._attn(self._ln1(x))
    return x + self._out(self._ln2(x))


class ImageGPT(nn.Module):
  """The ImageGPT Model."""
  
  def __init__(self,       
               in_channels,
               out_channels,
               in_size,
               n_transformer_blocks=8,
               n_attention_heads=4,
               n_embedding_channels=16):
    """Initializes a new ImageGPT instance.
    
    Args:
      in_channels: The number of input channels.
      out_channels: The number of output channels.
      in_size: Size of the input images. Used to create positional encodings.
      n_transformer_blocks: Number of TransformerBlocks to use.
      n_attention_heads: Number of attention heads to use.
      n_embedding_channels: Number of attention embedding channels to use.
    """
    super().__init__()
    self._pos = nn.Parameter(torch.zeros(1, in_channels, in_size, in_size))
    self._input = pg_nn.CausalConv2d(
        mask_center=True,
        in_channels=in_channels,
        out_channels=n_embedding_channels,
        kernel_size=3,
        padding=1)
    self._transformer = nn.Sequential(
        *[TransformerBlock(n_channels=n_embedding_channels,
                           n_attention_heads=n_attention_heads)
          for _ in range(n_transformer_blocks)])
    self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
    self._out = nn.Conv2d(in_channels=n_embedding_channels,
                          out_channels=out_channels,
                          kernel_size=1)

  def forward(self, x):
    # Add learned positional embeddings, then embed the input causally.
    x = self._input(x + self._pos)
    x = self._transformer(x)
    x = self._ln(x)
    return self._out(x)
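
As a quick sanity check, the hand-rolled model above can be run on a dummy batch; a minimal sketch, assuming the binarized MNIST shapes (1 channel, 28x28) used elsewhere in this README:

# Smoke test for the ImageGPT class defined above (shapes are assumptions).
model = ImageGPT(in_channels=1, out_channels=1, in_size=28)
batch = torch.zeros(2, 1, 28, 28)  # (batch, channels, height, width)
print(model(batch).shape)  # Expected: torch.Size([2, 1, 28, 28])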

Supported Algorithms

pytorch-generative supports the following algorithms.

We train likelihood-based models on dynamically binarized MNIST and report the negative log likelihood (in nats) in the tables below.

Autoregressive Models

Algorithm       Binarized MNIST (nats)  Links
PixelSNAIL      78.61                   Code, Paper
ImageGPT        79.17                   Code, Paper
Gated PixelCNN  81.50                   Code, Paper
PixelCNN        81.45                   Code, Paper
MADE            84.87                   Code, Paper
NADE            85.65                   Code, Paper
FVSBN           96.58                   Code, Paper

Variational Autoencoders

NOTE: The results below are the (variational) upper bound on the negative log likelihood (or, equivalently, the lower bound on the log likelihood).

Algorithm  Binarized MNIST (nats)  Links
VD-VAE     <= 80.72                Code, Paper
VAE        <= 86.77                Code, Paper
BetaVAE    N/A                     Code, Paper
VQ-VAE     N/A                     Code, Paper
VQ-VAE-2   N/A                     Code, Paper

Normalizing Flows

NOTE: Bits per dimension (bits/dim) can be calculated as (nll / 784 + log(256)) / log(2), where 784 is the MNIST dimension, log(256) accounts for dequantizing pixel values, and log(2) converts from natural log to base 2.
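
For reference, a minimal sketch of this conversion in Python (the function name is ours, not part of the library):

import math

def nats_to_bits_per_dim(nll_nats, n_dims=784):
  """Converts a per-image negative log likelihood in nats to bits/dim."""
  # log(256) accounts for dequantizing 8-bit pixel values; log(2) converts
  # from natural log to base 2.
  return (nll_nats / n_dims + math.log(256)) / math.log(2)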

Algorithm  MNIST (bits/dim)  Links
NICE       4.34              Code, Paper

Miscellaneous

Algorithm                                 Links
Mixture Models                            Code, Wiki
Kernel Density Estimators                 Code, Wiki
Neural Style Transfer                     Code, Blog, Paper
Compositional Pattern Producing Networks  Code, Wiki
