
autoregressive's Introduction


Autoregressive

This repository contains all the necessary PyTorch code, tailored to my presentation, to train and generate data from WaveNet-like autoregressive models.

For presentation purposes, the WaveNet-like models are applied to randomized Fourier series (1D) and MNIST (2D). In the figure below, two WaveNet-like models with different training settings make an n-step prediction on a periodic time-series from the validation dataset.

Advanced functions show how to generate MNIST images and how to progressively estimate the MNIST digit class p(y=class|x) from observed pixels using a class-conditional WaveNet p(x|y=class) and Bayes' rule. Left: sampled MNIST digits; right: progressive class estimates as more pixels are observed.
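The Bayes-rule step can be sketched as follows. The sketch assumes the class-conditional model has already produced the per-pixel log-likelihoods log p(x_t | x_<t, y) for every class y. Since p(y | x_1..x_t) is proportional to p(y) times the product of these terms, the running posterior is a cumulative sum in log space that is renormalized at each step. The function name `progressive_posterior` and its input layout are hypothetical and not the library's API. A uniform class prior is assumed by default.

```python
import math
from typing import List, Optional


def progressive_posterior(
    step_loglik: List[List[float]],
    log_prior: Optional[List[float]] = None,
) -> List[List[float]]:
    """Return p(y | x_1..x_t) for every number of observed pixels t.

    step_loglik[c][t] is assumed to hold log p(x_t | x_<t, y=c), as produced
    by a class-conditional autoregressive model (hypothetical input layout).
    """
    num_classes = len(step_loglik)
    num_steps = len(step_loglik[0])
    if log_prior is None:
        # assume a uniform class prior p(y)
        log_prior = [-math.log(num_classes)] * num_classes
    # running log p(y) + sum_i log p(x_i | x_<i, y)
    cum = list(log_prior)
    posteriors = []
    for t in range(num_steps):
        cum = [cum[c] + step_loglik[c][t] for c in range(num_classes)]
        # Bayes' rule: normalize over classes with a stable log-sum-exp
        m = max(cum)
        log_z = m + math.log(sum(math.exp(v - m) for v in cum))
        posteriors.append([math.exp(v - log_z) for v in cum])
    return posteriors
```

Each entry of the result corresponds to one column of the progressive class estimates, moving from left to right as more pixels are observed.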

Note: this library does not implement (Gated) PixelCNNs (oord2016conditional). Instead, it unrolls images into 1D sequences so that WaveNet architectures can process them. This works surprisingly well.
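Below is a minimal sketch of the unrolling. It assumes row-major (raster-scan) order, and whether the library uses exactly this order is an assumption. The helper names `unroll` and `roll_up` are hypothetical.

```python
import torch


def unroll(images: torch.Tensor) -> torch.Tensor:
    """B x C x H x W -> B x C x (H*W), scanning rows top to bottom."""
    return images.flatten(start_dim=-2)


def roll_up(seq: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Inverse of unroll: B x C x (H*W) -> B x C x H x W."""
    return seq.reshape(*seq.shape[:-1], height, width)
```

A 28x28 MNIST image becomes a sequence of 784 steps. This matches the observed/generated pixel counts in the image-support issue further down (for example 392 + 392).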

Features

Currently, the following features are implemented:

  • WaveNet architecture and training as proposed in (oord2016wavenet)
  • Conditioning support (oord2016wavenet)
  • Fast generation based on (paine2016fast)
  • Fully differentiable n-step unrolling in training (heindl2021autoreg)
  • 2D image generation, completion, classification, and progressive classification support based on MNIST dataset
  • A randomized Fourier dataset
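For reference, the core WaveNet building block (oord2016wavenet) is a dilated causal convolution. Stacking such layers with growing dilations enlarges the receptive field. The sketch below is a generic PyTorch illustration, not the repository's own layer. The names `CausalConv1d` and `receptive_field` are hypothetical.

```python
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv1d(nn.Module):
    """Dilated 1D convolution whose output at time t only sees inputs <= t."""

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 2, dilation: int = 1):
        super().__init__()
        self.left_pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: B x C x T
        # pad only on the left so that no future samples leak into the output;
        # the output keeps the input length T
        return self.conv(F.pad(x, (self.left_pad, 0)))


def receptive_field(kernel_size: int, dilations: Sequence[int]) -> int:
    """Number of input steps that influence one output of stacked causal convs."""
    return (kernel_size - 1) * sum(dilations) + 1
```

With kernel size 2 and dilations 1, 2, 4, ..., 512, as in the WaveNet paper, one stack covers 1024 input steps.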

Presentation

A detailed presentation with theoretical background, architectural considerations and experiments can be found below.

The presentation source as well as all generated images are public domain. In case you find them useful, please leave a citation (see References below). All presentation sources can be found in etc/presentation. The presentation is written in Markdown using Marp; graph diagrams are created using yEd.

If you spot errors or have suggestions for improvements, please let me know by opening an issue.

Installation

To install, run

pip install git+https://github.com/cheind/autoregressive.git#egg=autoregressive[dev]

which requires Python 3.9 and a recent PyTorch (> 1.9).

Usage

The library comes with a set of pre-trained models in models/. The following commands use those models to make various predictions. Many listed commands come with additional parameters; use --help to get additional information.

1D Fourier series

Sample new signals from scratch

python -m autoregressive.scripts.wavenet_signals sample --config "models/fseries_q127/config.yaml" --ckpt "models/fseries_q127/xxxxxx.ckpt" --condition 4 --horizon 1000

The default model conditions on the periodicity of the signal. For the pre-trained model, the condition is an integer in [0..4], corresponding to periods of 5-10 secs.
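To illustrate what such a signal looks like, here is a sketch of a randomized Fourier series with period P. Random coefficients a_n and b_n multiply the harmonics cos(2 pi n t / P) and sin(2 pi n t / P). The following are all assumptions: the coefficient distribution, the number of terms, the duration and the sampling step. The library's `FSeriesDataModule` may generate its data differently. How the condition values 0..4 map to concrete periods is not shown here.

```python
import math
import random
from typing import List, Optional


def random_fourier_series(
    period: float,
    num_terms: int = 4,
    duration: float = 20.0,
    dt: float = 0.05,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Sample f(t) = sum_n a_n cos(2 pi n t / P) + b_n sin(2 pi n t / P) on a grid."""
    if rng is None:
        rng = random.Random()
    # hypothetical choice: standard-normal coefficients for each harmonic
    coeffs = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(num_terms)]
    num_steps = int(round(duration / dt))
    signal = []
    for i in range(num_steps):
        w = 2.0 * math.pi * (i * dt) / period  # phase of the fundamental at time t
        signal.append(
            sum(a * math.cos(n * w) + b * math.sin(n * w)
                for n, (a, b) in enumerate(coeffs, start=1))
        )
    return signal
```

Every harmonic completes an integer number of cycles per period. As a result, the sampled signal repeats exactly every P / dt steps.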

Predict the shape of partially observable curves.

python -m autoregressive.scripts.wavenet_signals predict --config "models/fseries_q127/config.yaml" --ckpt "models/fseries_q127/xxxxxx.ckpt" --horizon 1500 --num_observed 50 --num_trajectories 20 --num_curves 1 --show_confidence true

2D MNIST

To sample from the class-conditional model

python -m autoregressive.scripts.wavenet_mnist sample --config "models/mnist_q2/config.yaml" --ckpt "models/mnist_q2/xxxxxx.ckpt"

Generate images conditioned on the digit class and observed pixels.

python -m autoregressive.scripts.wavenet_mnist predict --config "models/mnist_q2/config.yaml" --ckpt "models/mnist_q2/xxxxxx.ckpt" 

To perform classification

python -m autoregressive.scripts.wavenet_mnist classify --config "models/mnist_q2/config.yaml" --ckpt "models/mnist_q2/xxxxxx.ckpt"

Train

To train / reproduce a model

python -m autoregressive.scripts.train fit --config "models/mnist_q2/config.yaml"

Progress is logged to TensorBoard

tensorboard --logdir lightning_logs

To generate a training configuration file for a specific dataset use

python -m autoregressive.scripts.train fit --data autoregressive.datasets.FSeriesDataModule --print_config > fseries_config.yaml

Test

To run the tests

pytest

References

@misc{heindl2021autoreg, 
  title={Autoregressive Models}, 
  journal={PROFACTOR Journal Club}, 
  author={Heindl, Christoph},
  year={2021},
  howpublished={\url{https://github.com/cheind/autoregressive}}
}

@article{oord2016wavenet,
  title={Wavenet: A generative model for raw audio},
  author={Oord, Aaron van den and Dieleman, Sander and Zen, Heiga and Simonyan, Karen and Vinyals, Oriol and Graves, Alex and Kalchbrenner, Nal and Senior, Andrew and Kavukcuoglu, Koray},
  journal={arXiv preprint arXiv:1609.03499},
  year={2016}
}

@article{paine2016fast,
  title={Fast wavenet generation algorithm},
  author={Paine, Tom Le and Khorrami, Pooya and Chang, Shiyu and Zhang, Yang and Ramachandran, Prajit and Hasegawa-Johnson, Mark A and Huang, Thomas S},
  journal={arXiv preprint arXiv:1611.09482},
  year={2016}
}

@article{oord2016conditional,
  title={Conditional image generation with pixelcnn decoders},
  author={Oord, Aaron van den and Kalchbrenner, Nal and Vinyals, Oriol and Espeholt, Lasse and Graves, Alex and Kavukcuoglu, Koray},
  journal={arXiv preprint arXiv:1606.05328},
  year={2016}
}

autoregressive's People

Contributors

cheind


autoregressive's Issues

quantizer theory

https://colab.research.google.com/github/spatialaudio/digital-signal-processing-lecture/blob/master/quantization/linear_uniform_characteristic.ipynb#scrollTo=Oc-w6mUgXiiP
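As a small illustration of the linear uniform characteristic discussed in the linked notebook, the sketch below does two things. It maps values in [lo, hi] to `levels` equally wide bins (a mid-rise quantizer), and it reconstructs each bin by its center. The function names and the default range [-1, 1] are assumptions. The resulting indices are the kind of categorical targets an autoregressive model would predict.

```python
from typing import List


def quantize(x: List[float], levels: int, lo: float = -1.0, hi: float = 1.0) -> List[int]:
    """Map real values in [lo, hi] to bin indices 0..levels-1 (uniform bins)."""
    step = (hi - lo) / levels
    out = []
    for v in x:
        v = min(max(v, lo), hi)  # clip to the quantizer range
        # values exactly at hi would land in bin `levels`, so fold them into the last bin
        out.append(min(int((v - lo) / step), levels - 1))
    return out


def dequantize(q: List[int], levels: int, lo: float = -1.0, hi: float = 1.0) -> List[float]:
    """Reconstruct each index as the center of its bin."""
    step = (hi - lo) / levels
    return [lo + (i + 0.5) * step for i in q]
```

For inputs inside [lo, hi], the reconstruction error is at most half a bin width, (hi - lo) / (2 * levels).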

config.yaml files contain a `fit:` root element

This seems to cause problems when actually trying to perform a fit:

python -m autoregressive.scripts.train fit --config models\fseries_q127\config.yaml
usage: train.py [-h] [--config CONFIG] [--print_config [={comments,skip_null}+]] {fit,validate,test,predict,tune} ...
train.py: error: Configuration check failed :: Key "fit.data" is required but not included in config object or its value is None.

However, passing `--config` before the subcommand works:

python -m autoregressive.scripts.train --config models\fseries_q127\config.yaml fit

With the most current version,

python -m autoregressive.scripts.train fit --data autoregressive.datasets.MNISTDataModule --print_config > config.yaml

does not generate a `fit:` root element.
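Until the generated and shipped configs agree, one possible workaround is to unwrap the `fit:` root before passing the file to the `fit` subcommand. The helper below operates on an already-parsed dict. Its name is made up for illustration.

```python
from typing import Any, Dict


def strip_subcommand_root(cfg: Dict[str, Any], subcommand: str = "fit") -> Dict[str, Any]:
    """Return the config without a single top-level `<subcommand>:` wrapper.

    Configs saved as {'fit': {...}} are unwrapped; anything else is returned unchanged.
    """
    if set(cfg) == {subcommand} and isinstance(cfg[subcommand], dict):
        return cfg[subcommand]
    return cfg
```

For example, you could load the file with PyYAML's `yaml.safe_load`, unwrap it, and write it back with `yaml.safe_dump`.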

plot weight histograms

if self.histograms:
    layer = 'layer{}'.format(layer_index)
    tf.histogram_summary(layer + '_filter', weights_filter)
    tf.histogram_summary(layer + '_gate', weights_gate)
    tf.histogram_summary(layer + '_dense', weights_dense)
    tf.histogram_summary(layer + '_skip', weights_skip)
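A possible PyTorch counterpart of the TensorFlow snippet above logs one histogram per parameter tensor. It assumes a writer that provides `add_histogram(tag, values, global_step)`, such as `torch.utils.tensorboard.SummaryWriter`. The function name and the tag scheme are hypothetical.

```python
import torch.nn as nn


def log_weight_histograms(writer, model: nn.Module, step: int) -> None:
    """Log one histogram per parameter tensor, tagged by its module path.

    `writer` is anything with an add_histogram(tag, values, global_step) method,
    e.g. torch.utils.tensorboard.SummaryWriter.
    """
    for name, param in model.named_parameters():
        # "0.weight" -> "weights/0/weight" groups the histograms per layer in TensorBoard
        writer.add_histogram(
            "weights/" + name.replace(".", "/"),
            param.detach().cpu(),
            global_step=step,
        )
```

With PyTorch Lightning's `TensorBoardLogger`, `self.logger.experiment` is such a `SummaryWriter`. The function could therefore be called, for instance, from `on_train_epoch_end`.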

Decouple input and output representations

Currently, we assume that what we get as input is what we will predict as output (just shifted). However, with other research areas in mind, it might make sense to rework this more generally:

model
  input: BxIxT
  output: BxQxT

where I may match Q but does not have to. In training, we would then have code like the following:

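import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    inputs = batch['x']
    if 't' in batch:
        targets = batch['t']  # allows us to provide alternative targets
    elif self.input_channels == self.quantization_levels:  # I == Q (hypothetical attribute names)
        # same representation in and out: predict the input shifted by one step
        targets = inputs[..., 1:]
        inputs = inputs[..., :-1]
    else:
        raise ValueError('explicit targets are required when I != Q')

    logits = self.forward(inputs)  # BxQxT
    loss = F.cross_entropy(logits, targets)
    return loss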

Moreover, we need to think about input transforms. Currently, one-hot encoding is hardwired into the model. We might instead pass a differentiable input_transform to the model upon initialization, which would allow us to use differentiable embedding strategies.
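One way to inject the input transform at construction time is sketched below. The model only sees the output of `input_transform`, so one-hot encoding and a learned `nn.Embedding` become interchangeable. In the notation above, I equals Q for one-hot and equals the embedding size otherwise. All class names are hypothetical, and the 1x1 convolution stands in for the actual WaveNet body.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OneHot(nn.Module):
    """Current behaviour: integer codes B x T -> one-hot floats B x Q x T."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.one_hot(x, self.num_classes).permute(0, 2, 1).float()


class LearnedEmbedding(nn.Module):
    """Differentiable alternative: integer codes B x T -> learned vectors B x E x T."""

    def __init__(self, num_classes: int, dim: int):
        super().__init__()
        self.emb = nn.Embedding(num_classes, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.emb(x).permute(0, 2, 1)


class TinyAutoregressive(nn.Module):
    """Hypothetical model that receives its input transform at construction time."""

    def __init__(self, input_transform: nn.Module, in_channels: int, num_classes: int):
        super().__init__()
        self.input_transform = input_transform
        # stand-in for the WaveNet body: maps I input channels to Q logits
        self.head = nn.Conv1d(in_channels, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.input_transform(x))  # B x Q x T logits
```

Because the transform is an `nn.Module` attribute, the embedding's parameters are registered with the model and trained jointly with it.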

Switch to quantized encoding and one-hot

discussion
https://github.com/ibab/tensorflow-wavenet/issues/83
https://github.com/ibab/tensorflow-wavenet/issues/219

We could then also use the Gumbel(-softmax) trick to reparametrize the categorical distribution and allow unrolling during training.

In issue 83 above, note that they discuss a larger kernel for the initial convolution.
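Here is a sketch of how the Gumbel-softmax trick could make n-step unrolling differentiable. Each predicted categorical distribution is replaced by `torch.nn.functional.gumbel_softmax(..., hard=True)`. This yields a one-hot sample in the forward pass while letting gradients flow through the soft relaxation. The function name is illustrative, and so is the assumption that the model consumes one-hot inputs of shape BxQxT. This is not the repository's implementation; heindl2021autoreg may unroll differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def unroll_gumbel(model: nn.Module, x: torch.Tensor, steps: int, tau: float = 1.0) -> torch.Tensor:
    """Differentiably roll out `steps` predictions.

    x: B x Q x T one-hot (or soft) history; model maps B x Q x T -> B x Q x T logits.
    Returns B x Q x steps samples that stay connected to the autograd graph.
    """
    outs = []
    for _ in range(steps):
        logits = model(x)[..., -1]  # B x Q, prediction for the next step
        # hard=True: one-hot in the forward pass, soft gradients in the backward pass
        sample = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)
        outs.append(sample)
        x = torch.cat([x, sample.unsqueeze(-1)], dim=-1)  # feed the sample back in
    return torch.stack(outs, dim=-1)
```

Lower values of `tau` make the soft relaxation closer to a true categorical sample, at the cost of higher-variance gradients.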

What is meant by global conditioning only?

I am attempting to use your wavenet implementation to model some climate data, where my condition vector changes with time. The code mentions only global conditioning is currently supported. What exactly does this mean from an architecture perspective?
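In the WaveNet paper (oord2016wavenet), global conditioning means a single condition vector per sequence, for example a speaker identity or, here, a digit class. That vector is added to the filter and gate pre-activations at every time step. Local conditioning instead uses a condition sequence aligned with the signal's time axis. The sketch below shows both cases with one hypothetical module. A condition of shape BxCx1 broadcasts over time (global), while one of shape BxCxT varies per step (local). The 1x1 convolutions stand in for the dilated causal convolution. This is an illustration, not the library's code.

```python
import torch
import torch.nn as nn


class GatedUnit(nn.Module):
    """WaveNet-style gated activation with a condition c.

    c of shape B x C x 1 acts as a global condition (broadcast over time);
    c of shape B x C x T is a local, time-varying condition.
    """

    def __init__(self, channels: int, cond_channels: int):
        super().__init__()
        # stand-ins for the dilated causal filter/gate convolutions W_f, W_g
        self.filt = nn.Conv1d(channels, channels, kernel_size=1)
        self.gate = nn.Conv1d(channels, channels, kernel_size=1)
        # condition projections V_f, V_g
        self.cond_filt = nn.Conv1d(cond_channels, channels, kernel_size=1)
        self.cond_gate = nn.Conv1d(cond_channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # a B x C x 1 condition broadcasts over all T time steps
        f = self.filt(x) + self.cond_filt(c)
        g = self.gate(x) + self.cond_gate(c)
        return torch.tanh(f) * torch.sigmoid(g)
```

A time-varying condition such as climate covariates corresponds to the local case. It needs to be resampled to the signal's time resolution before it enters each layer.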

add image support

Early support is in the image-support branch; v48 is running on GPUs. It seems to be quite a nice model so far. R=347, global condition = class.

top row: observed 392 px
bottom row: generated 392px
wavenet-mnist-r347-o392-g392

top row: observed 200 px
bottom row: generated 584 px
wavenet-mnist-r347-o200-g584

top row: observed 1 px
bottom row: generated 783 px
wavenet-mnist-r347-o1-g783b

top row: observed 200 px
bottom row: generated 584 px
wavenet-mnist-r347-o200-g584b
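For reference, completing an image from its first N observed pixels amounts to sampling the remaining 784 - N pixels one at a time. The sketch below is the naive variant, which re-runs the model on the whole prefix at every step. The fast generation of paine2016fast avoids this by caching intermediate activations. The sketch assumes a hypothetical model that maps one-hot inputs BxQxT to logits BxQxT, where `logits[..., t]` predicts pixel t+1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def complete(model: nn.Module, observed: torch.Tensor, total: int, num_classes: int) -> torch.Tensor:
    """Naively sample the remaining pixels of an unrolled image.

    observed: B x N integer pixel codes (the first N pixels in raster order).
    Re-runs the full model on the growing prefix at every step (no caching).
    """
    seq = observed
    while seq.shape[-1] < total:
        x = F.one_hot(seq, num_classes).permute(0, 2, 1).float()  # B x Q x T
        logits = model(x)[..., -1]  # B x Q, distribution of the next pixel
        nxt = torch.distributions.Categorical(logits=logits).sample()  # B
        seq = torch.cat([seq, nxt.unsqueeze(-1)], dim=-1)
    return seq
```

Calling `complete(model, observed, 784, num_classes)` followed by `reshape(-1, 28, 28)` yields the completed images, for example 392 observed plus 392 generated pixels.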
