SIXO: Smoothing Inference with Twisted Objectives

This repository contains code for the paper "SIXO: Smoothing Inference with Twisted Objectives", to appear at NeurIPS 2022. The paper is also available as a preprint.

Repository Structure

datasets.py              Dataset loaders
bounds.py                Functions for computing variational lower bounds (e.g., SIXO, FIVO)
smc.py                   Sequential Monte Carlo implementation
diffusion_train.py       A script for training Gaussian diffusion models
svm_train.py             A script for training the stochastic volatility model on forex data
hh_train.py              A script for training the Hodgkin-Huxley model
hh_eval.py               A script for evaluating trained Hodgkin-Huxley models
hh_inf.py                A script for performing inference in the Hodgkin-Huxley model
util.py                  Utility functions

models/                  Directory containing model implementations
 ├── base.py             Model superclasses
 ├── diffusion.py        Gaussian diffusion model implementation
 ├── hh.py               Hodgkin-Huxley model implementation
 └── svm.py              Stochastic volatility model implementation

data/                    Directory containing datasets
 └── forex_data.tsv      Foreign exchange dataset used for the SVM

test/                    Unit tests

bin/                                   Directory with bash scripts for launching training sweeps
 ├── gdd/                              Gaussian diffusion training scripts
 ├── hh/                               Hodgkin-Huxley training scripts
 │   ├── model_learning_sweep          Hodgkin-Huxley model learning scripts
 │   ├── eval_model_learning_sweep     Hodgkin-Huxley model evaluation scripts
 │   └── inf_sweep                     Hodgkin-Huxley inference scripts
 └── svm/                              SVM training scripts
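As a rough illustration of the kind of quantity that an SMC-based variational bound estimates, the following sketch runs a bootstrap particle filter on a toy one-dimensional linear Gaussian model and returns a log-evidence estimate, whose expectation is a lower bound on the true log marginal likelihood. It is not the code in smc.py or bounds.py: the model, the function names, and the observation values are all assumptions made for this example, and it omits the learned twists that SIXO adds on top of this filtering-style estimate.

```python
import math
import random

def log_normal_pdf(x, mean, std):
    """Log density of a univariate Gaussian."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)

def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))

def bootstrap_log_evidence(observations, num_particles, rng):
    """Estimate log p(y_1..y_T) for x_t ~ N(x_{t-1}, 1), y_t ~ N(x_t, 1), x_0 = 0."""
    particles = [0.0] * num_particles
    log_evidence = 0.0
    for y in observations:
        # Propose from the transition prior (the "bootstrap" proposal).
        particles = [rng.gauss(x, 1.0) for x in particles]
        # With this proposal the incremental weight is the observation likelihood.
        log_weights = [log_normal_pdf(y, x, 1.0) for x in particles]
        log_evidence += log_mean_exp(log_weights)
        # Multinomial resampling in proportion to the weights.
        m = max(log_weights)
        weights = [math.exp(lw - m) for lw in log_weights]
        particles = rng.choices(particles, weights=weights, k=num_particles)
    return log_evidence

rng = random.Random(0)
ys = [0.3, -0.1, 0.8, 1.2, 0.9]  # made-up observations for the toy model
estimate = bootstrap_log_evidence(ys, num_particles=256, rng=rng)
print(estimate)
```

Increasing num_particles generally tightens the bound at the cost of more computation per step.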

Running the code

First, install the requirements with pip install -r requirements.txt. The main libraries used are JAX, NumPy, SciPy, Matplotlib, TensorFlow, pytest, Chex, Optax, Equinox, and snax.

Then you can run the Python scripts directly. For example, the following command trains an SVM:

python3 svm_train.py \
  --bound=sixo \
  --train_num_particles=4 \
  --eval_num_particles=2048 \
  --eval_batch_size=1 \
  --num_train_steps=10000000 \
  --lr=2e-4 \
  --init_scale=0.01 \
  --min_scale_diag=1e-12 \
  --model_inner_steps=400 \
  --tilt_inner_steps=100 \
  --tilt_type=bwd_dre \
  --tilt_batch_size=64 \
  --tilt_lr=1e-3 \
  --dre_tilt_rnn_hdims=32 \
  --dre_tilt_mlp_hdims=128 \
  --summarize_every=20000 \
  --expensive_summary_mult=10 \
  --checkpoint_dir=/tmp/svm \
  --checkpoint_every=10000 \
  --checkpoints_to_keep=10 \
  --seed=0

Alternatively, you can run training using the scripts in the bin directory. Each directory has two bash files, run.sh and launch_sweep.sh. Each run.sh script accepts a few arguments and launches an individual run. The launch_sweep.sh scripts iterate over hyperparameter combinations and launch a series of runs on a SLURM cluster. You can also use the run.sh scripts without a SLURM cluster by passing the appropriate command-line arguments locally. For example, the following command runs the same SVM training as above:

cd bin/svm && sh run.sh svm_train_proj sixo multinomial 2e-4 bwd_dre 1e-3 0

The run scripts assume that you have created a venv called .env in the root directory that includes all dependencies in requirements.txt.
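To make the sweep idea concrete, here is a minimal sketch that enumerates hyperparameter combinations and prints one run.sh command line per combination, as a dry run. It is not the repository's launch_sweep.sh, which is a bash script targeting SLURM. The positional argument order is inferred from the single example above, the parameter names (project, resampling, and so on) are guesses at what each position means, and the grid values other than sixo, 2e-4, and seed 0 are hypothetical.

```python
import itertools
import shlex

# Hypothetical grid; only "sixo", "2e-4", and seed 0 come from the example above.
BOUNDS = ["sixo", "fivo"]
LEARNING_RATES = ["2e-4", "1e-4"]
SEEDS = [0, 1]

def sweep_commands(project="svm_train_proj", resampling="multinomial",
                   tilt_type="bwd_dre", tilt_lr="1e-3"):
    """Yield one `sh run.sh ...` command line per hyperparameter combination."""
    for bound, lr, seed in itertools.product(BOUNDS, LEARNING_RATES, SEEDS):
        # Assumed positional order, inferred from the single run.sh example.
        args = ["sh", "run.sh", project, bound, resampling,
                lr, tilt_type, tilt_lr, str(seed)]
        yield shlex.join(args)

commands = list(sweep_commands())
for line in commands:
    print(line)  # dry run: print the command instead of executing it
```

Printing the commands first makes it easy to check the grid before launching anything; each line could then be run from bin/svm or handed to a scheduler.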

Running the tests

To run the tests, make sure you have pytest installed and then run pytest in the main directory.

Citation

@inproceedings{lawson2022sixo,
  title={{SIXO}: Smoothing Inference with Twisted Objectives},
  author={Dieterich Lawson and Allan Ravent{\'o}s and Andrew Warrington and Scott Linderman},
  booktitle={Advances in Neural Information Processing Systems},
  editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
  year={2022},
  url={https://openreview.net/forum?id=bDyLgfvZ0qJ}
}
