
maf's Introduction

🎄 Masked Autoregressive Flow with PyTorch

This is a PyTorch implementation of the masked autoregressive flow (MAF) by Papamakarios et al. [1].

The Gaussian MADE that makes up each layer in the MAF is found in MADE.py, while the MAF itself is found in maf.py.

Datasets

The files in the data folder are adapted from the original repository by G. Papamakarios [2]. G. Papamakarios et al. have kindly made the preprocessed datasets available to the public, and they can be downloaded through this link: https://zenodo.org/record/1161203#.Wmtf_XVl8eN.

Example

Remember to download the datasets first, and then run

python3 train.py

This will train a five-layer MAF on the MNIST dataset. Each MAF layer (i.e., each Gaussian MADE) has a single hidden layer of 512 units. Following the approach of the original paper [1], we use the natural ordering of the inputs and reverse it after each MAF layer. The model is trained using the Adam optimiser with a learning rate of 1e-4 and early stopping with a patience of 30. Please have a look inside train.py to see the rest of the hyperparameters and their default values.
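To make the construction concrete, here is a rough, self-contained sketch of the idea. This is an illustrative toy, not the code in MADE.py or maf.py: the masking is simplified and the batchnorm layers between MAF layers are omitted. Each layer is a masked affine autoregressive transform, the ordering is reversed after each layer, and the per-layer log-determinants are summed to evaluate the log-likelihood under a standard Gaussian base.

```python
import math
import torch
import torch.nn as nn

def made_masks(dim, hidden):
    """MADE-style binary masks so that output i only depends on inputs with index < i."""
    in_deg = torch.arange(1, dim + 1)
    hid_deg = torch.arange(hidden) % (dim - 1) + 1
    out_deg = torch.arange(1, dim + 1)
    mask_in = (hid_deg[:, None] >= in_deg[None, :]).float()   # (hidden, dim)
    mask_out = (out_deg[:, None] > hid_deg[None, :]).float()  # (dim, hidden)
    return mask_in, mask_out

class ToyGaussianMADE(nn.Module):
    """One affine autoregressive layer with a single masked hidden layer."""

    def __init__(self, dim, hidden=512):
        super().__init__()
        mask_in, mask_out = made_masks(dim, hidden)
        self.register_buffer("mask_in", mask_in)
        self.register_buffer("mask_out", mask_out)
        self.W1 = nn.Parameter(0.01 * torch.randn(hidden, dim))
        self.b1 = nn.Parameter(torch.zeros(hidden))
        self.W_mu = nn.Parameter(0.01 * torch.randn(dim, hidden))
        self.W_logp = nn.Parameter(0.01 * torch.randn(dim, hidden))

    def forward(self, x):
        h = torch.relu(x @ (self.W1 * self.mask_in).t() + self.b1)
        mu = h @ (self.W_mu * self.mask_out).t()
        logp = h @ (self.W_logp * self.mask_out).t()
        u = (x - mu) * torch.exp(0.5 * logp)   # data -> noise
        log_det = 0.5 * logp.sum(dim=-1)       # log |det du/dx|
        return u, log_det

def log_likelihood(layers, x):
    """Map x through the stack and evaluate the standard Gaussian base density."""
    u, log_det_total = x, x.new_zeros(x.shape[0])
    for layer in layers:
        u, log_det = layer(u)
        log_det_total = log_det_total + log_det
        u = u.flip(-1)  # reverse the ordering before the next layer
    log_base = (-0.5 * u.pow(2) - 0.5 * math.log(2 * math.pi)).sum(dim=-1)
    return log_base + log_det_total

# Five MAF layers of 512 hidden units each, as in the default configuration above.
flow = nn.ModuleList(ToyGaussianMADE(dim=784, hidden=512) for _ in range(5))
```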

By changing the number of hidden units in each MADE to 1024, the model will converge to a test log-likelihood similar to the one reported in [1]. Interestingly, fewer hidden units in each layer give better test results in terms of likelihood.

Do not hesitate to kill (ctrl + c) the training when the validation loss has increased for more than five consecutive epochs (don't worry, the best model on validation data is already saved). In my experience, the validation loss rarely starts decreasing again after that. Alternatively, you can change the patience manually by editing train.py.
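For reference, the early-stopping logic amounts to something like the sketch below; the function and argument names are illustrative and not the ones used in train.py.

```python
import torch

def fit_with_early_stopping(model, optimiser, train_one_epoch, evaluate,
                            train_loader, val_loader, max_epochs=1000,
                            patience=30, ckpt_path="best_model.pt"):
    """Train until the validation loss has not improved for `patience` epochs.
    `train_one_epoch` and `evaluate` are assumed to be user-supplied callables."""
    best_val, since_improvement = float("inf"), 0
    for epoch in range(max_epochs):
        train_one_epoch(model, train_loader, optimiser)
        val_loss = evaluate(model, val_loader)
        if val_loss < best_val:
            best_val, since_improvement = val_loss, 0
            torch.save(model.state_dict(), ckpt_path)  # keep the best model on val data
        else:
            since_improvement += 1
            if since_improvement >= patience:
                break  # stop: no improvement for `patience` consecutive epochs
    return best_val
```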

Decent sample quality for MNIST is achieved in ~10 epochs, but further training squeezes out more performance in terms of a higher average log-likelihood at test time.

Training runs smoothly when done locally on a 2018 MacBook Pro with 16 GB RAM. (Be prepared for a noisy fan and a burning hot laptop.)

Visualisations

An animation for a MAF with default settings trained for 20 epochs (one frame per epoch) on the MNIST dataset. The validation loss after 20 epochs was 1299.3 +/- 1.6, with error bands corresponding to two empirical standard deviations. After ~70 epochs, the validation loss was down to 1283.0 +/- 1.9.

[animation]

Below are 80 random samples from the best saved model during training. Note that the samples are not sorted by likelihood, and that some of them are garbage.

[figure: 80 random samples]

[figure: marginal distributions of 80 random pixels] This figure shows the marginal distributions of 80 random pixels, based on 1000 random test samples. For a perfect data-to-noise mapping, we would expect each of the marginals to follow a standard Gaussian. This holds for some pixels, but the majority seem to have a longer lower tail than desired.

[figure: pairwise scatterplots] The same trend is reflected in these scatterplots, which, under a perfect data-to-noise mapping, would be bivariate standard Gaussians. However, the third quadrant is in many cases relatively overpopulated, reflecting the lower tails of the marginal distributions.
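For completeness, a check like the one in these figures can be produced by mapping test samples to the base space and comparing the resulting marginals with a standard Gaussian. The sketch below assumes the trained flow exposes a data-to-noise mapping with the signature model(x) -> (u, log_det); the actual plotting code used for the figures above may differ.

```python
import numpy as np
import matplotlib.pyplot as plt
import torch

def plot_noise_marginals(model, test_x, n_pixels=8, n_samples=1000):
    """Histogram a few marginals of the mapped noise against the standard normal density."""
    with torch.no_grad():
        u, _ = model(test_x[:n_samples])   # assumed signature: x -> (u, log_det)
    u = u.cpu().numpy()
    idx = np.random.choice(u.shape[1], size=n_pixels, replace=False)
    grid = np.linspace(-4, 4, 200)
    std_normal = np.exp(-0.5 * grid ** 2) / np.sqrt(2 * np.pi)
    fig, axes = plt.subplots(1, n_pixels, figsize=(2 * n_pixels, 2))
    for ax, i in zip(axes, idx):
        ax.hist(u[:, i], bins=40, density=True)
        ax.plot(grid, std_normal)          # a perfect mapping would match this curve
        ax.set_title(f"pixel {i}")
    plt.tight_layout()
    plt.show()
```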

To do

The validation loss diverges during the first few epochs for some of the datasets (not MNIST) before it stabilises and gets on the right track. Check:

  • The weight initialisation, which is different from the one used in the MAF paper.
  • The preprocessing of the datasets.
  • The batchnorm layer.

Separately, there are sometimes a (very) few test samples that are not well handled by the model (even for MNIST). This makes the test log-likelihood artificially low and its sample variance artificially high: typically fewer than three samples have likelihoods that are orders of magnitude off from the rest. Find out what causes this and how to deal with it (a simple check for such outliers is sketched below).
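A minimal sketch of such a check (illustrative only): flag test samples whose log-likelihood lies many empirical standard deviations below the mean, and report the average with and without them.

```python
import numpy as np

def flag_outlier_log_likelihoods(log_likelihoods, n_sigma=5.0):
    """Return indices of samples whose log-likelihood is more than `n_sigma`
    empirical standard deviations below the mean."""
    ll = np.asarray(log_likelihoods, dtype=np.float64)
    z = (ll - ll.mean()) / ll.std()
    return np.where(z < -n_sigma)[0]

# Usage (test_ll is an array of per-sample test log-likelihoods):
# bad = flag_outlier_log_likelihoods(test_ll)
# print(test_ll.mean(), np.delete(test_ll, bad).mean(), bad)
```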

[1] https://arxiv.org/abs/1705.07057

[2] https://github.com/gpapamak/maf/blob/master/datasets/

maf's People

Contributors: e-hulten

maf's Issues

FR: Volume Preserving Modification

Hi Edvard,

I was wondering if it is possible for you to add an option to make your MAF implementation a volume-preserving transform? I need this for a specific purpose, and it should be easy, but I don't know your code well enough.

All one needs to do is make sure the last scale variable is the negative sum of the others. That way the log-determinant is always zero.
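To make the request concrete, here is a rough sketch of the constraint; the names are illustrative and not taken from your code. Given the unconstrained log-scales produced by the MADE, replace the last one with the negative sum of the others, so the per-sample sum (and hence the log-determinant) is always zero.

```python
import torch

def volume_preserving_log_scales(logp):
    """logp: (batch, dim) unconstrained log-scales from the MADE.
    Returns log-scales whose per-sample sum is zero, so the affine layer is
    volume preserving (the log-determinant is zero for every sample)."""
    head = logp[:, :-1]
    last = -head.sum(dim=-1, keepdim=True)   # negative sum of the others
    return torch.cat([head, last], dim=-1)
```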

The idea is demonstrated with Real NVP in this paper, but it is very simple:
https://arxiv.org/abs/2001.04872

Much appreciated if you can make this modification!

Transform the MAF layer into an IAF layer

Hi Edvard,

I need to invert a MAF layer (using another flow function and not an affine one); however, I am having trouble grasping the concept of it.
So:

  1. When training normalizing flows, to use the change of variables we actually use f^{-1} of the transformer function. For an affine function this is (x - mu) * e^{0.5 * sigma}, which is (x - mu) * torch.exp(0.5 * logp) in your code.
  2. When sampling from the distribution we use the forward function of the transformer. For an affine function like the one here, that is mu_i + u_i * sigma_i. In your code that is mu[:, dim] + u[:, dim] * torch.exp(mod_logp[:, dim]).

Now, if I invert that layer by swapping forward and backward, I have f in the forward pass and f^{-1} in the backward pass. Is this the way to obtain an IAF?
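To fix notation, here is a rough sketch of the two directions from points 1 and 2, assuming a MADE-style conditioner made(x) -> (mu, logp) and the convention u = (x - mu) * exp(0.5 * logp); the names are illustrative, and the repo's mod_logp may be parameterised differently.

```python
import torch

def data_to_noise(made, x):
    """Density-evaluation direction: all dimensions in parallel, since mu and
    logp are functions of the observed x only."""
    mu, logp = made(x)
    u = (x - mu) * torch.exp(0.5 * logp)
    log_det = 0.5 * logp.sum(dim=-1)
    return u, log_det

@torch.no_grad()  # sampling does not need gradients
def noise_to_data(made, u):
    """Sampling direction: sequential, one dimension at a time, because mu_i
    and logp_i depend on the x_{<i} generated so far."""
    x = torch.zeros_like(u)
    for i in range(u.shape[-1]):
        mu, logp = made(x)  # only the already-filled dimensions < i matter for dim i
        x[:, i] = mu[:, i] + u[:, i] * torch.exp(-0.5 * logp[:, i])
    return x
```

Reading the same pair of maps in the other direction, with the MADE conditioned on the noise u rather than on the data x, corresponds to the relationship between MAF and IAF discussed in [1]: swapping the roles of the two passes gives an IAF-style layer with fast sampling and sequential density evaluation.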
