
GWG_release

Official release of code for "Oops I Took A Gradient: Scalable Sampling for Discrete Distributions", accepted for a long presentation at ICML 2021.

The paper is by me, Kevin Swersky, Milad Hashemi, David Duvenaud, and Chris Maddison.

Gibbs-With-Gradients sampling from an Ising model (the animated visualization is in the repository).

Code for sampling experiments can be found in:

rbm_sample.py, ising_sample.py, fhmm_sample.py, potts_sample.py, svgd_sample.py
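
For orientation, here is a minimal, unbatched sketch of the binary Gibbs-With-Gradients step these scripts build on. This is a sketch, not the repo's implementation: the actual samplers are batched and support extra options; f is assumed to be any differentiable unnormalized log-probability.

    import torch

    def gwg_step(x, f):
        # x: binary tensor of shape (D,); f: differentiable unnormalized log-prob.
        x = x.detach().requires_grad_()
        fx = f(x)
        grad = torch.autograd.grad(fx, x)[0]
        # First-order estimate of f(flip_i(x)) - f(x) for every dimension i.
        d_fwd = -(2.0 * x - 1.0) * grad
        q_fwd = torch.distributions.Categorical(logits=d_fwd.detach() / 2.0)
        i = q_fwd.sample()
        x_new = x.detach().clone()
        x_new[i] = 1.0 - x_new[i]
        # Reverse proposal evaluated at the flipped state.
        x_new = x_new.requires_grad_()
        fx_new = f(x_new)
        grad_new = torch.autograd.grad(fx_new, x_new)[0]
        d_rev = -(2.0 * x_new - 1.0) * grad_new
        q_rev = torch.distributions.Categorical(logits=d_rev.detach() / 2.0)
        # Metropolis-Hastings correction for the gradient-informed proposal.
        log_alpha = (fx_new - fx).detach() + q_rev.log_prob(i) - q_fwd.log_prob(i)
        if torch.rand(()).log() < log_alpha:
            return x_new.detach()
        return x.detach()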

To generate training data for ising inference experiments run:

./generate_data.sh

Datasets for EBM training can be found at:

https://github.com/jmtomczak/vae_vampprior/tree/master/datasets

Download them and unzip as:

GWG_release/

    datasets/
        Caltech...
        FreyFaces...
        Histo...
        MNIST_static/
        Omniglot/

If you would like access to the protein data, please contact me at [email protected]; the files are quite large and don't fit here :(

To train a binary EBM run:

python pcd_ebm_ema.py --save_dir $DIR \
    --sampler gwg --sampling_steps $NUM_STEPS --viz_every 100 \
    --model resnet-64 --print_every 10 --lr .0001 --warmup_iters 10000 --buffer_size 10000 --n_iters 50000 \
    --buffer_init mean --base_dist --reinit_freq 0.0 \
    --eval_every 5000 --eval_sampling_steps 10000
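
For intuition, the script implements persistent contrastive divergence with an EMA copy of the weights. Below is a schematic of one update, assuming sampler_step is a batched MCMC transition (e.g. a batched version of the GWG step above) and buffer is a tensor of persistent chains; the names here are illustrative, not the script's actual API.

    import torch

    def pcd_update(model, ema_model, x_real, buffer, sampler_step, opt,
                   k_steps=50, ema_decay=0.999):
        idx = torch.randint(len(buffer), (x_real.size(0),))
        x_fake = buffer[idx]
        for _ in range(k_steps):
            x_fake = sampler_step(x_fake, model)   # refresh persistent chains
        buffer[idx] = x_fake.detach()
        # Maximize the log-likelihood gap between data and model samples.
        loss = -(model(x_real).mean() - model(x_fake).mean())
        opt.zero_grad()
        loss.backward()
        opt.step()
        # EMA copy of the weights, used for evaluation.
        with torch.no_grad():
            for p_ema, p in zip(ema_model.parameters(), model.parameters()):
                p_ema.mul_(ema_decay).add_(p, alpha=1.0 - ema_decay)
        return loss.item()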

To train a categorical EBM run:

python pcd_ebm_ema_cat.py --save_dir $DIR \
    --sampler gwg --sampling_steps $NUM_STEPS --viz_every 100 \
    --model resnet-64 --proj_dim $PROJ_DIM --print_every 10 --lr .0001 --warmup_iters 10000 --buffer_size 1000 \
    --n_iters 50000 \
    --buffer_init mean --base_dist --p_control 0.0 --reinit_freq 0.0 \
    --eval_every 5000 --eval_sampling_steps 10000 --dataset_name $DATA
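
The categorical sampler follows the same recipe over (position, category) pairs. Here is a sketch of the proposal logits for a single one-hot sample, again assuming f is a differentiable unnormalized log-probability (illustrative, not the repo's exact code):

    import torch

    def gwg_cat_proposal_logits(x, f):
        # x: one-hot tensor of shape (D, K).
        x = x.detach().requires_grad_()
        grad = torch.autograd.grad(f(x), x)[0]      # (D, K)
        cur = (grad * x).sum(-1, keepdim=True)      # gradient at the current class
        # First-order estimate of f(x with position i set to class j) - f(x);
        # zero for the currently active class.
        d_tilde = grad - cur
        return (d_tilde / 2.0).detach().reshape(-1) # D*K move logits

Sampling an index from Categorical(logits=...) picks both the position to change and its new value; the Metropolis-Hastings correction is then the same as in the binary case.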

To evaluate with AIS (this takes a while) run:

python eval_ais.py \
--ckpt_path $CKPT_PATH \
    --save_dir $DIR \
    --sampler gwg --model resnet-64 --buffer_size 10000 \
    --n_iters 300000 --base_dist --n_samples 500 \
    --eval_sampling_steps 300000 --ema --viz_every 1000
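
As a reminder of what this evaluation computes: AIS anneals from a tractable, normalized base distribution to the EBM, accumulating importance weights along the temperature path. A schematic estimator with hypothetical names (sample_p0 draws exact base samples, transition is one MCMC step such as GWG):

    import math
    import torch

    def ais_log_z(f, log_p0, sample_p0, transition, n_samples=500, n_betas=10000):
        betas = torch.linspace(0.0, 1.0, n_betas)
        x = sample_p0(n_samples)                     # exact samples from the base
        log_w = torch.zeros(n_samples)
        for b0, b1 in zip(betas[:-1], betas[1:]):
            log_w += (b1 - b0) * (f(x) - log_p0(x))  # importance-weight increment
            # One MCMC step targeting the intermediate distribution
            # (1 - b1) * log_p0(x) + b1 * f(x).
            x = transition(x, lambda y: (1.0 - b1) * log_p0(y) + b1 * f(y))
        # log Z estimate; the base is normalized, so log Z0 = 0.
        return torch.logsumexp(log_w, dim=0) - math.log(n_samples)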


Issues

No mask provided in code for DiffSampler with binary data

Hi, thank you for providing this code. I really enjoyed reading your paper. I tried running pcd_ebm_ema.py as per the README on Static MNIST and received the following error:

Traceback (most recent call last):
  File "pcd_ebm_ema.py", line 303, in <module>
    main(args)
  File "pcd_ebm_ema.py", line 196, in main
    x_fake_new = sampler.step(x_fake.detach(), model).detach()
TypeError: step() missing 1 required positional argument: 'mask'

It looks like the DiffSampler expects a mask but is not receiving one.
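
Until the call site is fixed, one hypothetical workaround (untested, and assuming the mask simply marks which dimensions are eligible to be flipped) would be to pass an all-ones mask explicitly:

    # Hypothetical patch for pcd_ebm_ema.py, line 196 (untested assumption
    # that an all-ones mask makes every dimension eligible).
    mask = torch.ones_like(x_fake)
    x_fake_new = sampler.step(x_fake.detach(), model, mask).detach()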

Regarding MNIST-RBM

Hi authors!

Thank you for providing the source code for this excellent work!

I have a question regarding the MNIST-RBM experiments (rbm_sample.py). It seems that all the methods are initialized with a Gibbs sampling result (in other words, model.init_dist is changed by first running Gibbs sampling for 5,000 steps). Did you try sampling without this initialization? Moreover, since all the methods are initialized by Gibbs (the yellow curve), why do they have a much higher log-MMD than the yellow curve? Finally, what is the algorithmic difference between dim-gibbs (the blue curve) and gibbs (the yellow curve)?

[Figure: log-MMD over sampling steps for each sampler]

Looking forward to your reply. Thanks in advance!

Question about the code for the Potts model

Hi, thank you for providing this code!

In L185-186 of pcd_potts.py:

  • logp_real = (model(x).squeeze() * weights).mean()
  • logp_fake = model(x_fake).squeeze().mean()

  1. Why is the reweighting only applied to x?
  2. Should it be logp_real = (model(x).squeeze() * weights).sum() / weights.sum(), i.e. a proper weighted mean? (See the sketch below.)
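
For concreteness, here is the difference between the two estimators, with stand-in tensors for model(x).squeeze() and weights:

    import torch

    vals = torch.tensor([1.0, 2.0, 3.0])     # stand-in for model(x).squeeze()
    weights = torch.tensor([0.2, 0.3, 0.5])
    mean_scaled = (vals * weights).mean()                   # what the code computes
    weighted_mean = (vals * weights).sum() / weights.sum()  # a true weighted mean
    # The two agree only when the weights have mean 1, since
    # (vals * weights).mean() = (vals * weights).sum() / len(weights).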

Applying GWG to protein variants

I enjoyed reading your paper.

I've trained a neural network model to predict a specific functional measurement value for mutants of a given protein. I would like to use GWG to sample high-performing mutations. My protein of interest is of length ~1500 amino acids.

I'm wondering if this would be as simple as instantiating a DiffSamplerMultiDim with approx=True (the dim parameter doesn't seem to be used anywhere) and temperature=2.

I think a one-hot representation wouldn't work because it carries the constraint that, at a given position, only one of the 20 indices can be 1, which the algorithm doesn't enforce. So perhaps there could be a mapping from a 5-dimensional binary vector to an amino acid index (see the sketch below).
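
A minimal sketch of that encoding (2**5 = 32 >= 20, so every amino acid index gets a unique 5-bit code; the 12 unused codes would still need to be masked or penalized somehow):

    import torch

    N_BITS = 5  # 2**5 = 32 >= 20 amino acids

    def aa_to_bits(idx):
        # Map an amino acid index (0..19) to a 5-dimensional binary vector.
        return torch.tensor([(idx >> b) & 1 for b in range(N_BITS)],
                            dtype=torch.float)

    def bits_to_aa(bits):
        # Inverse mapping; codes 20..31 do not correspond to any amino acid.
        return int(sum(int(b) << i for i, b in enumerate(bits)))

    assert bits_to_aa(aa_to_bits(17)) == 17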

I'm wondering if you think this would work? Thanks!
