jacobjinkelly / easy-neural-ode

Code for the paper "Learning Differential Equations that are Easy to Solve"

Home Page: https://arxiv.org/abs/2007.04504

License: MIT License

Python 100.00%
neural-ode differential-equations machine-learning jax deep-learning deep-neural-networks dynamical-systems neural-differential-equations numerical-integration ode

easy-neural-ode's Introduction

Learning Differential Equations that are Easy to Solve

Code for the paper:

Jacob Kelly*, Jesse Bettencourt*, Matthew James Johnson, David Duvenaud. "Learning Differential Equations that are Easy to Solve." Neural Information Processing Systems (2020). [arxiv] [bibtex]

*Equal Contribution

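Includes JAX implementations of the following models:

  • Neural ODEs for classification (mnist.py)
  • Latent ODEs (latent_ode.py)
  • FFJORD (ffjord_tabular.py, ffjord_mnist.py)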

Includes JAX implementations of the following adaptive-stepping numerical solvers:

  • Heun-Euler: heun (2nd order)
  • Fehlberg (RK1(2)): fehlberg (2nd order)
  • Bogacki-Shampine: bosh (3rd order)
  • Cash-Karp: cash_karp (4th order)
  • Fehlberg: rk_fehlberg (4th order)
  • Owren-Zennaro: owrenzen (4th order)
  • Dormand-Prince: dopri (5th order)
  • Owren-Zennaro: owrenzen5 (5th order)
  • Tanaka-Yamashita: tanyam (7th order)
  • Adams: adams (adaptive order)
  • RK4: rk4 (4th order, fixed step-size)
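To illustrate what "adaptive-stepping" means for the solvers listed above, here is a minimal sketch of the Heun-Euler embedded pair in plain Python for a scalar ODE. It is not the repository's heun implementation: the function name heun_euler_adaptive, its arguments, and the step-size controller constants (0.9, 0.2, 5.0) are assumptions chosen for illustration.

```python
import math


def heun_euler_adaptive(f, t0, y0, t1, rtol=1e-6, atol=1e-8, h0=0.1):
    """Integrate dy/dt = f(t, y) from t0 to t1 (scalar y).

    Returns (y at t1, number of function evaluations).
    Hypothetical helper for illustration, not part of this repository.
    """
    t, y, h = t0, y0, h0
    nfe = 0
    while t < t1:
        last = h >= t1 - t
        if last:
            h = t1 - t  # do not step past the end point
        # A production solver would reuse k1 after a rejected step;
        # this sketch recomputes it for simplicity.
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        nfe += 2
        y_low = y + h * k1                 # Euler step (1st order)
        y_high = y + 0.5 * h * (k1 + k2)   # Heun step (2nd order)
        err = abs(y_high - y_low)          # local error estimate
        tol = atol + rtol * max(abs(y), abs(y_high))
        if err <= tol:
            # Accept the step and advance with the higher-order result.
            y = y_high
            t = t1 if last else t + h
        # The error estimate scales like h**2, hence the square root.
        factor = 5.0 if err == 0.0 else 0.9 * math.sqrt(tol / err)
        h *= min(5.0, max(0.2, factor))
    return y, nfe


y_end, nfe = heun_euler_adaptive(lambda t, y: -y, 0.0, 1.0, 1.0)
print(y_end, nfe)
```

The returned function-evaluation count depends on both the tolerances and the dynamics, which is the cost that the paper's regularization aims to reduce.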

Requirements

Python

Please use python>=3.8

JAX

Follow the JAX installation instructions.

Haiku

Follow the Haiku installation instructions.

Tensorflow Datasets

To use the MNIST dataset, follow the Tensorflow Datasets installation instructions.

Usage

Different scripts are provided for each task and dataset.

MNIST Classification

python mnist.py --reg r3 --lam 6e-5

Latent ODEs

python latent_ode.py --reg r3 --lam 1e-2

FFJORD (Tabular)

python ffjord_tabular.py --reg r2 --lam 1e-2

FFJORD (MNIST)

python ffjord_mnist.py --reg r2 --lam 3e-4

Datasets

MNIST

tensorflow-datasets (instructions for installing above) will download the data when called from the training script.

Physionet

The file physionet_data.py, adapted from Latent ODEs for Irregularly-Sampled Time Series, will download and process the data when called from the training script. A preprocessed version is available in releases.

Tabular (FFJORD)

Data must be downloaded following instructions from gpapamak/maf and placed in data/. Only MINIBOONE is needed for experiments in the paper.

Code in datasets/, adapted from Free-form Jacobian of Reversible Dynamics (FFJORD), will create an interface for the MINIBOONE dataset once it's downloaded. It is called from the training script.

Acknowledgements

Code in lib is modified from google/jax under its license.

Several numerical solvers were adapted from torchdiffeq and DifferentialEquations.jl.

BibTeX

@inproceedings{kelly2020easynode,
  title={Learning Differential Equations that are Easy to Solve},
  author={Kelly, Jacob and Bettencourt, Jesse and Johnson, Matthew James and Duvenaud, David},
  booktitle={Neural Information Processing Systems},
  year={2020},
  url={https://arxiv.org/abs/2007.04504}
}

easy-neural-ode's Issues

toy examples

Are there any toy examples or tutorials available? I would be interested in applying this approach to other datasets.

Source for Tanyam

Is there a reference for the Tanyam solver? I am unfamiliar with it, and don't see it online or in the paper

Error on trying to run the experiments

I am trying to run the experiments mnist.py and ffjord_mnist.py, but I encounter a runtime error in each:

  1. mnist.py
Traceback (most recent call last):
  File "mnist.py", line 684, in <module>
    run()
  File "mnist.py", line 526, in run
    forward, model = init_model()
  File "mnist.py", line 275, in init_model
    ode_dim = jnp.prod(ode_shape[1:])
  File "/home/avik-pal/.local/lib/python3.7/site-packages/jax/numpy/lax_numpy.py", line 1797, in reduction
    _check_arraylike(name, a)
  File "/home/avik-pal/.local/lib/python3.7/site-packages/jax/numpy/lax_numpy.py", line 297, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: prod requires ndarray or scalar arguments, got <class 'tuple'> at position 0.
  2. ffjord_mnist.py
Traceback (most recent call last):
  File "ffjord_mnist.py", line 730, in <module>
    run()
  File "ffjord_mnist.py", line 645, in run
    batch = next(ds_train)[0]
TypeError: '_IterableDataset' object is not an iterator

I am using the CPU version of jax. Would it be possible to get the version bounds for the dependencies?
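Both tracebacks point at ordinary Python type mismatches, so a possible workaround can be sketched without JAX. The first error comes from passing a tuple to jnp.prod; for a static shape, math.prod (Python 3.8+) computes the same product, and wrapping the tuple in jnp.array before calling jnp.prod would be another option. The second error comes from calling next() on an object that is iterable but not an iterator. The shape value and the IterableOnly stand-in class below are hypothetical, and whether these match the fix the maintainers would choose is not known. Note also that the traceback paths show python3.7, while the README asks for python>=3.8.

```python
import math

# 1. Product of a static shape tuple without going through jnp.prod.
ode_shape = (1, 28, 28, 1)          # hypothetical shape, for illustration only
ode_dim = math.prod(ode_shape[1:])  # math.prod requires Python >= 3.8


# 2. next() needs an iterator; an object that only defines __iter__
#    must be wrapped with iter() first.
class IterableOnly:
    """Stand-in for a dataset object that is iterable but not an iterator."""

    def __init__(self, items):
        self._items = items

    def __iter__(self):
        return iter(self._items)


ds_train = IterableOnly([("batch0", "label0"), ("batch1", "label1")])
ds_iter = iter(ds_train)   # create the iterator once, outside the loop
batch = next(ds_iter)[0]   # the same call pattern as in the traceback now works
print(ode_dim, batch)
```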

What's the difference between odeint and odeint_aux?

I have trouble understanding the naming in your ode lib. For example, it seems the argument 'rev_func' is never used in odeint_sepaux. And what is the difference between 'nodeint_aux' and 'all_nodeint' in ffjord_mnist?

Tutorial Application

It would be great to see how this regularized training method can be applied to the Implicit Layers Neural ODE Tutorial.

Just to increase one's understanding, it would be helpful to see (1) how the regularization fits into the training loop and (2) how jet differs from applying first order A.D. repeatedly.

I am going to try and tackle this over the next week, but would appreciate any help from anyone who may be interested! -- Thanks
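As a rough illustration of both points, here is a plain-Python sketch for the scalar ODE dz/dt = z**2. Instead of differentiating repeatedly, it propagates a truncated Taylor series of the solution and obtains every order in a single pass through the recurrence (k + 1) * z_{k+1} = [f(z(t))]_k, which is the idea behind Taylor-mode differentiation. This is a hand-written recurrence for one specific ODE, not the jet routine mentioned above, and all function names are hypothetical. The regularized loss adds lam times the squared K-th total derivative at a single point; the paper integrates this quantity along the trajectory, and lam mirroring the --lam flag is an assumption.

```python
import math


def solution_taylor_coeffs(z0, order):
    """Taylor coefficients z_0..z_order of the solution of dz/dt = z**2, z(0) = z0."""
    coeffs = [z0]
    for k in range(order):
        # k-th Taylor coefficient of f(z(t)) = z(t)**2 via a Cauchy product
        # of the coefficients computed so far.
        f_k = sum(coeffs[i] * coeffs[k - i] for i in range(k + 1))
        coeffs.append(f_k / (k + 1))
    return coeffs


def total_derivative(z0, K):
    """K-th total time derivative of the solution at t = 0."""
    return math.factorial(K) * solution_taylor_coeffs(z0, K)[K]


def regularized_loss(task_loss, z0, K, lam):
    """Task loss plus lam times the squared K-th total derivative (single point)."""
    return task_loss + lam * total_derivative(z0, K) ** 2


# For z0 = 1 the exact solution is 1 / (1 - t), whose Taylor coefficients are all 1.
print(solution_taylor_coeffs(1.0, 4))
print(regularized_loss(0.5, 1.0, 3, 1e-2))
```

In a training loop, the second term would be computed from the model's dynamics at each solver state and added to the task loss before taking gradients.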

Return type is tuple of solution and nfe not just solution

The return type is actually a tuple of the solution and the number of function evaluations:

return jax.vmap(unravel)(out), nfe

Super easy to clean up; the first time I came across it, I just assumed it was an issue with pytrees.

I'm currently going through the repo to understand the code base a bit better: would you prefer if I opened an issue for every tiny thing I find, or would it be easier to work with a general clean up PR?
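For callers that expect only the solution, the pair has to be unpacked. A minimal sketch with a stand-in fixed-step Euler integrator (euler_with_nfe is hypothetical and is not a function from this repository):

```python
import math


def euler_with_nfe(f, y0, t0, t1, num_steps):
    """Fixed-step Euler that returns (solution, number of function evaluations)."""
    h = (t1 - t0) / num_steps
    t, y, nfe = t0, y0, 0
    for _ in range(num_steps):
        y = y + h * f(t, y)
        t += h
        nfe += 1
    return y, nfe


result = euler_with_nfe(lambda t, y: -y, 1.0, 0.0, 1.0, 100)
solution, nfe = result  # unpack instead of treating result as the solution
print(solution, nfe)
```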

Error running latent_ode.py

I tried running the script on the Physionet data and got the following error after a few iterations. Can you comment on this, and also say a bit more about what the expected output is?
TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

Upon running it again, it would just hang here:

 python latent_ode.py --reg r3 --lam 1e-2
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
~/.conda/envs/Neural_ODE/lib/python3.8/site-packages/jax/_src/random.py:511: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation
  warnings.warn(msg, FutureWarning)

conda environment:

_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       1_gnu    conda-forge
absl-py                   0.13.0                    <pip>
backcall                  0.2.0                     <pip>
ca-certificates           2021.5.30            ha878542_0    conda-forge
certifi                   2021.5.30        py38h578d9bd_0    conda-forge
cycler                    0.10.0                    <pip>
Cython                    0.29.19                   <pip>
debugpy                   1.3.0                     <pip>
dm-haiku                  0.0.5.dev0                <pip>
flatbuffers               2.0                       <pip>
future                    0.18.2                    <pip>
ipykernel                 6.0.0                     <pip>
ipython                   7.25.0                    <pip>
ipython-genutils          0.2.0                     <pip>
jax                       0.2.17                    <pip>
jaxlib                    0.1.68                    <pip>
jedi                      0.18.0                    <pip>
jmp                       0.0.2                     <pip>
joblib                    0.15.1                    <pip>
jupyter-client            6.1.12                    <pip>
jupyter-core              4.7.1                     <pip>
kiwisolver                1.2.0                     <pip>
ld_impl_linux-64          2.36.1               hea4e1c9_0    conda-forge
libffi                    3.3                  h58526e2_2    conda-forge
libgcc-ng                 9.3.0               h2828fa1_19    conda-forge
libgomp                   9.3.0               h2828fa1_19    conda-forge
libstdcxx-ng              9.3.0               h6de172a_19    conda-forge
matplotlib                3.2.1                     <pip>
matplotlib-inline         0.1.2                     <pip>
ncurses                   6.2                  h58526e2_4    conda-forge
numpy                     1.21.0                    <pip>
openssl                   1.1.1k               h7f98852_0    conda-forge
opt-einsum                3.3.0                     <pip>
parso                     0.8.2                     <pip>
pexpect                   4.8.0                     <pip>
phate                     1.0.7                     <pip>
pickleshare               0.7.5                     <pip>
pip                       21.1.3             pyhd8ed1ab_0    conda-forge
POT                       0.7.0                     <pip>
prompt-toolkit            3.0.19                    <pip>
ptyprocess                0.7.0                     <pip>
Pygments                  2.9.0                     <pip>
pyparsing                 2.4.7                     <pip>
python                    3.8.10          h49503c6_1_cpython    conda-forge
python-dateutil           2.8.1                     <pip>
python_abi                3.8                      2_cp38    conda-forge
pyzmq                     22.1.0                    <pip>
readline                  8.1                  h46c0cb4_0    conda-forge
s-gd2                     1.8                       <pip>
scikit-learn              0.23.1                    <pip>
scipy                     1.4.1                     <pip>
setuptools                49.6.0           py38h578d9bd_3    conda-forge
six                       1.15.0                    <pip>
sklearn                   0.0                       <pip>
sqlite                    3.36.0               h9cd32fc_0    conda-forge
tabulate                  0.8.9                     <pip>
threadpoolctl             2.1.0                     <pip>
tk                        8.6.10               h21135ba_1    conda-forge
torch                     1.5.0                     <pip>
torchdiffeq               0.0.1                     <pip>
tornado                   6.1                       <pip>
traitlets                 5.0.5                     <pip>
wcwidth                   0.2.5                     <pip>
wheel                     0.36.2             pyhd3deb0d_0    conda-forge
xz                        5.2.5                h516909a_1    conda-forge
zlib                      1.2.11            h516909a_1010    conda-forge
