

License: MIT

Implementation of the DRAW network architecture

This repository contains a reimplementation of the Deep Recurrent Attentive Writer (DRAW) network architecture introduced by K. Gregor, I. Danihelka, A. Graves and D. Wierstra. The original paper can be found at

http://arxiv.org/pdf/1502.04623

animation.gif

Dependencies

Draw currently works with the "cutting-edge development version". Because the API is subject to change, you might consider installing this known-to-be-supported version:

You also need to install

Data

You need to set the location of your data directory:

export FUEL_DATA_PATH=/home/user/data

fuel-download and fuel-convert are used to obtain and convert training datasets. E.g. for binarized MNIST

cd $FUEL_DATA_PATH
fuel-download binarized_mnist
fuel-convert binarized_mnist

or similarly for SVHN

cd $FUEL_DATA_PATH
fuel-download svhn -d . 2
fuel-convert svhn -d . 2
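
As a quick sanity check after the download and conversion steps above, the following sketch looks for a converted file under FUEL_DATA_PATH. The helper find_dataset and the file names in EXPECTED_FILES are assumptions made for illustration and are not part of this repository; check the names your Fuel version actually writes.

```python
import os

# Assumed output names of fuel-convert (hypothetical mapping for this sketch).
EXPECTED_FILES = {
    "bmnist": "binarized_mnist.hdf5",
    "svhn2": "svhn_format_2.hdf5",
}


def find_dataset(dataset, environ=os.environ):
    """Return the path of the converted file for `dataset`, or raise."""
    data_path = environ.get("FUEL_DATA_PATH")
    if not data_path:
        raise RuntimeError("FUEL_DATA_PATH is not set")
    filename = EXPECTED_FILES[dataset]
    # Search every directory listed in FUEL_DATA_PATH, in order.
    for directory in data_path.split(os.pathsep):
        candidate = os.path.join(directory, filename)
        if os.path.isfile(candidate):
            return candidate
    raise IOError("%s not found under %s" % (filename, data_path))
```

Calling find_dataset("bmnist") before launching a long training run surfaces a missing or misplaced file immediately rather than after compilation.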

Training with attention

To train a model with a 2x2 read and a 5x5 write attention window run

cd draw
./train-draw.py --dataset=bmnist --attention=2,5 --niter=64 --lr=3e-4 --epochs=100

On an Amazon g2.2xlarge instance, Theano's compilation takes more than 40 minutes before training starts. If you enable the bokeh-server, you can follow the live plots once training is running. Training the model takes about 2 days.
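
The --attention=2,5 value in the command above pairs a read window size with a write window size. The sketch below shows one way such a flag could be parsed with argparse. It is not the parsing code of train-draw.py; parse_attention and the default values are hypothetical.

```python
import argparse


def parse_attention(value):
    """Turn 'READ,WRITE' (e.g. '2,5') into a pair of window sizes."""
    try:
        read_n, write_n = (int(part) for part in value.split(","))
    except ValueError:
        raise argparse.ArgumentTypeError("expected READ,WRITE, e.g. 2,5")
    if read_n < 1 or write_n < 1:
        raise argparse.ArgumentTypeError("window sizes must be positive")
    return read_n, write_n


parser = argparse.ArgumentParser()
# Defaults here are placeholders for the sketch, not the script's defaults.
parser.add_argument("--attention", type=parse_attention, default=None)
parser.add_argument("--niter", type=int, default=10)

# Parse the same values used in the training command above.
args = parser.parse_args(["--attention=2,5", "--niter=64"])
read_n, write_n = args.attention  # 2x2 read window, 5x5 write window
```

Validating the pair at parse time means a malformed value such as "2;5" fails before the long compilation step begins.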

After each epoch, training saves the following files:

  • a pickle of the model
  • a pickle of the log
  • sampled output image for that epoch
  • animation of sampled output

Generating animations

To generate sampled output, including an animation, run

python sample.py svhn_model.pkl --channels 3 --size 32

Note that all dependencies are needed in order to load a model and generate samples. Unfortunately, this also includes the GPU, because Python cannot unpickle CudaNdarray objects without it. This is a known problem for which we don't yet have a general solution.

SVHN

To train a model on SVHN

python train-draw.py --name=my_svhn --dataset=svhn2 \
  --attention=5,5 --niter=32 --lr=3e-4 --epochs=100 \
  --enc-dim 512 --dec-dim 512

After 100-200 epochs, the model above achieved a test_nll_bound of 1825.82.

Log

Run

python plot-kl.py [pickle-of-log]

to create a visualization of the KL divergence plotted over inference iterations and epochs. E.g.:

KL-Divergence
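
For readers unfamiliar with the quantity being plotted, here is a minimal sketch of a per-iteration KL term. It assumes, as in the DRAW paper, a diagonal Gaussian posterior at each inference iteration and a standard normal prior. The function kl_per_iteration is hypothetical and is not taken from plot-kl.py.

```python
import math


def kl_per_iteration(mus, sigmas):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.

    `mus` and `sigmas` hold one list of per-dimension values for each
    inference iteration; the result has one KL value per iteration.
    """
    kls = []
    for mu_t, sigma_t in zip(mus, sigmas):
        kl = 0.5 * sum(
            m * m + s * s - math.log(s * s) - 1.0
            for m, s in zip(mu_t, sigma_t)
        )
        kls.append(kl)
    return kls


# Two iterations with a 2-dimensional latent each (made-up numbers).
example = kl_per_iteration(
    mus=[[0.0, 0.0], [1.0, -1.0]],
    sigmas=[[1.0, 1.0], [1.0, 1.0]],
)
```

A posterior that matches the prior contributes zero, so iterations where the curve is high are the ones in which the latent variables carry the most information.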

Testing

Run

nosetests -v tests

to execute the test suite. Run

cd draw
./attention.py

to test the attention windowing code on some image. It opens three windows: one displaying the original input image, one displaying some extracted, downsampled content (testing the read operation), and one showing the upsampled content (matching the input size) after the write operation.
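
The read and write operations described above can be sketched in plain Python, following the Gaussian filterbank formulation in the DRAW paper. This is an illustrative sketch and not the code in attention.py. The function names, the fixed sigma and stride, and the tiny 8x8 test image are all assumptions.

```python
import math


def filterbank(center, delta, sigma, n_filters, size):
    """Return n_filters rows of normalized 1-D Gaussian filters over `size` pixels."""
    rows = []
    for i in range(n_filters):
        # Filter centres are spaced `delta` apart, symmetric around `center`.
        mu = center + (i - n_filters / 2 + 0.5) * delta
        row = [math.exp(-((a - mu) ** 2) / (2 * sigma ** 2)) for a in range(size)]
        total = sum(row) or 1.0  # avoid dividing by zero if a filter underflows
        rows.append([v / total for v in row])
    return rows


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def read(image, f_y, f_x, gamma=1.0):
    """Extract a downsampled N x N patch: gamma * F_Y . image . F_X^T."""
    patch = matmul(matmul(f_y, image), transpose(f_x))
    return [[gamma * v for v in row] for row in patch]


def write(patch, f_y, f_x, gamma=1.0):
    """Project a patch back to image size: (1 / gamma) * F_Y^T . patch . F_X."""
    canvas = matmul(matmul(transpose(f_y), patch), f_x)
    return [[v / gamma for v in row] for row in canvas]


# An 8x8 constant image read through a 2x2 window, then written back.
image = [[1.0] * 8 for _ in range(8)]
bank = filterbank(center=3.5, delta=2.0, sigma=1.0, n_filters=2, size=8)
patch = read(image, bank, bank)     # 2 x 2
canvas = write(patch, bank, bank)   # 8 x 8
```

Because each filter row is normalized to sum to one, reading a constant image returns a patch with that same constant. The same filterbank matrices serve both directions, which is why the written output matches the input size.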
