jacobjinkelly / easy-neural-ode

Code for the paper "Learning Differential Equations that are Easy to Solve"

Home Page: https://arxiv.org/abs/2007.04504

License: MIT License

Python 100.00%
neural-ode differential-equations machine-learning jax deep-learning deep-neural-networks dynamical-systems neural-differential-equations numerical-integration ode ode-solver

easy-neural-ode's Issues

What's the difference between odeint and odeint_aux?

I have trouble understanding the naming in your ODE library. For example, the argument 'rev_func' never seems to be used in odeint_sepaux. And what is the difference between 'nodeint_aux' and 'all_nodeint' in ffjord_mnist?

Tutorial Application

It would be great to see how this regularized training method can be applied in the Implicit Layers Neural ODE tutorial.

To deepen understanding, it would help to see (1) how the regularization fits into the training loop and (2) how jet differs from applying first-order AD repeatedly.

I am going to try to tackle this over the next week, but I would appreciate help from anyone who is interested. Thanks!
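As a hedged illustration of point (2), the following sketch (not taken from this repository) computes the third total time derivative of an ODE solution z(t), with dz/dt = f(z), in two ways. Taylor-mode AD via jax.experimental.jet feeds all known lower-order derivatives of z(t) through f in one pass per order, while repeated first-order AD nests jax.jvp calls, so its cost tends to grow much faster with the derivative order. The toy dynamics f and the helpers derivs_jet and third_deriv_nested_jvp are assumptions for illustration; the squared norm of such a derivative is the kind of quantity an R_K-style penalty integrates over time.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.jet import jet


def f(z):
    # Toy autonomous dynamics dz/dt = f(z), chosen only for illustration.
    # Plain lax primitives are used so every op has a jet rule.
    return lax.mul(lax.sin(z), z)


def derivs_jet(f, z, order):
    """Return [z', z'', ..., z^(order)] along the solution through z, via jet."""
    derivs = [f(z)]  # z' = f(z)
    for _ in range(order - 1):
        # Input series: known derivatives of z(t). Output series: derivatives
        # of f(z(t)); the highest one is the next derivative of z(t).
        _, out_series = jet(f, (z,), (tuple(derivs),))
        derivs.append(out_series[-1])
    return derivs


def third_deriv_nested_jvp(f, z):
    """z''' via repeated first-order AD (nested jvp calls)."""
    g = lambda y: jax.jvp(f, (y,), (f(y),))[1]  # g(z) = Df(z) f(z) = z''
    return jax.jvp(g, (z,), (f(z),))[1]         # z''' = Dg(z) f(z)


z = jnp.array([0.3, -1.2, 2.0])
z1, z2, z3 = derivs_jet(f, z, order=3)
z3_nested = third_deriv_nested_jvp(f, z)
penalty_integrand = jnp.sum(z3 ** 2)  # ||d^3 z/dt^3||^2 at this point
```

Both routes should agree numerically; they differ in how the work scales as the order increases.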

Return type is tuple of solution and nfe not just solution

The return type is actually a tuple of the solution and the number of function evaluations:

return jax.vmap(unravel)(out), nfe

This is easy to clean up; the first time I came across it, I thought it was an issue with pytrees.
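To make the tuple return concrete, here is a minimal sketch of a hypothetical fixed-step RK4 solver (odeint_rk4_with_nfe is an illustrative name, not the repository's odeint). It flattens a pytree state with jax.flatten_util.ravel_pytree, counts function evaluations, and ends with the same return jax.vmap(unravel)(out), nfe pattern. An adaptive solver would report a varying count; this fixed-step sketch always makes four evaluations per step.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def odeint_rk4_with_nfe(func, y0, ts):
    """Hypothetical fixed-step RK4 solver over the time grid ts.

    Returns (solution, nfe): the solution has y0's pytree structure with a
    leading time axis, and nfe counts calls to func.
    """
    flat_y0, unravel = ravel_pytree(y0)

    def flat_func(y, t):
        # Evaluate the pytree-valued dynamics on the flat state vector.
        return ravel_pytree(func(unravel(y), t))[0]

    def step(carry, t_next):
        y, t, nfe = carry
        h = t_next - t
        k1 = flat_func(y, t)
        k2 = flat_func(y + h * k1 / 2, t + h / 2)
        k3 = flat_func(y + h * k2 / 2, t + h / 2)
        k4 = flat_func(y + h * k3, t + h)
        y_next = y + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        return (y_next, t_next, nfe + 4), y_next  # four evaluations per step

    init = (flat_y0, ts[0], jnp.array(0, dtype=jnp.int32))
    (_, _, nfe), ys = jax.lax.scan(step, init, ts[1:])
    out = jnp.concatenate([flat_y0[None], ys], axis=0)
    return jax.vmap(unravel)(out), nfe


# Harmonic oscillator x' = v, v' = -x with x(0) = 1, so x(t) = cos(t).
y0 = {"x": jnp.array(1.0, dtype=jnp.float32), "v": jnp.array(0.0, dtype=jnp.float32)}
ts = jnp.linspace(0.0, 1.0, 11)
sol, nfe = odeint_rk4_with_nfe(lambda y, t: {"x": y["v"], "v": -y["x"]}, y0, ts)
```

Callers have to unpack both values, as in the last line; code that expects a bare solution needs the same change.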

I'm currently going through the repo to understand the code base a bit better: would you prefer that I open an issue for every small thing I find, or would a general clean-up PR be easier to work with?

toy examples

Are there any toy examples or tutorials available? I would be interested in applying this approach to other datasets.
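In the spirit of this request, here is a minimal toy sketch (not from this repository) of training a small neural ODE with an R_K-style penalty added to the loss, which also touches on how such a regularizer fits into a training loop. Everything here is an illustrative assumption: a hypothetical two-layer MLP for the dynamics, a fixed-step Euler solver instead of an adaptive one, K = 2 so that a single jax.jvp yields d²z/dt² = Df(z) f(z), plain gradient descent, and a synthetic task of mapping each point z0 to 2 * z0 at t = 1. The weight lam trades off fitting the targets against penalizing the second derivative along the trajectories.

```python
import jax
import jax.numpy as jnp


def dynamics(params, z):
    # Hypothetical tiny MLP f_theta(z) for a 2-D state.
    w1, b1, w2, b2 = params
    return jnp.tanh(z @ w1 + b1) @ w2 + b2


def init_params(key, dim=2, hidden=16):
    k1, k2 = jax.random.split(key)
    return (0.5 * jax.random.normal(k1, (dim, hidden)), jnp.zeros(hidden),
            0.5 * jax.random.normal(k2, (hidden, dim)), jnp.zeros(dim))


def solve_with_reg(params, z0, t1=1.0, steps=20):
    """Euler-integrate dz/dt = f(z) and, alongside the state, accumulate the
    integral of ||d^2 z/dt^2||^2 (an R_2-style penalty)."""
    f = lambda z: dynamics(params, z)
    h = t1 / steps

    def step(carry, _):
        z, reg = carry
        dz = f(z)
        d2z = jax.jvp(f, (z,), (dz,))[1]  # d^2z/dt^2 = Df(z) f(z)
        return (z + h * dz, reg + h * jnp.sum(d2z ** 2)), None

    (z1, reg), _ = jax.lax.scan(step, (z0, jnp.zeros(())), None, length=steps)
    return z1, reg


def loss_fn(params, z0s, targets, lam):
    z1s, regs = jax.vmap(solve_with_reg, in_axes=(None, 0))(params, z0s)
    task_loss = jnp.mean(jnp.sum((z1s - targets) ** 2, axis=-1))
    return task_loss + lam * jnp.mean(regs)


@jax.jit
def train_step(params, z0s, targets, lam):
    loss, grads = jax.value_and_grad(loss_fn)(params, z0s, targets, lam)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
    return params, loss


# Synthetic toy task: flow each point z0 to 2 * z0 by time t = 1.
z0s = jax.random.normal(jax.random.PRNGKey(1), (64, 2))
targets = 2.0 * z0s
params = init_params(jax.random.PRNGKey(0))
losses = []
for _ in range(200):
    params, loss = train_step(params, z0s, targets, 1e-2)
    losses.append(float(loss))
```

Setting lam to 0 recovers unregularized training, which makes it easy to compare the two runs on the same toy data.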

Source for Tanyam

Is there a reference for the Tanyam solver? I am unfamiliar with it and cannot find it online or in the paper.

Error on trying to run the experiments

I am trying to run the experiments mnist.py and ffjord_mnist.py, but I encounter the following runtime errors:

  1. mnist.py
Traceback (most recent call last):
  File "mnist.py", line 684, in <module>
    run()
  File "mnist.py", line 526, in run
    forward, model = init_model()
  File "mnist.py", line 275, in init_model
    ode_dim = jnp.prod(ode_shape[1:])
  File "/home/avik-pal/.local/lib/python3.7/site-packages/jax/numpy/lax_numpy.py", line 1797, in reduction
    _check_arraylike(name, a)
  File "/home/avik-pal/.local/lib/python3.7/site-packages/jax/numpy/lax_numpy.py", line 297, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: prod requires ndarray or scalar arguments, got <class 'tuple'> at position 0.
  2. ffjord_mnist.py
Traceback (most recent call last):
  File "ffjord_mnist.py", line 730, in <module>
    run()
  File "ffjord_mnist.py", line 645, in run
    batch = next(ds_train)[0]
TypeError: '_IterableDataset' object is not an iterator

I am using the CPU version of jax. Would it be possible to get the version bounds for the dependencies?
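Both tracebacks appear to stem from changed library behavior rather than from the models themselves. A hedged sketch of the usual workarounds, using hypothetical stand-ins rather than the scripts' real objects: NumPy's np.prod accepts a tuple of ints, which the jax.numpy reduction in the first traceback rejects, and wrapping an iterable-but-not-iterator dataset in iter(...) once gives an object that next() accepts. The shape tuple and the IterableOnly class are illustrative assumptions.

```python
import numpy as np

# Fix 1 (mnist.py): compute a product of static shape entries with NumPy,
# which accepts a tuple, instead of jnp.prod, which rejects it.
ode_shape = (-1, 28, 28, 8)  # hypothetical shape tuple, for illustration only
ode_dim = int(np.prod(ode_shape[1:]))


# Fix 2 (ffjord_mnist.py): an iterable is not necessarily an iterator.
class IterableOnly:
    """Hypothetical stand-in for a dataset object that defines __iter__
    but not __next__, so next(obj) raises TypeError."""

    def __init__(self, batches):
        self._batches = batches

    def __iter__(self):
        return iter(self._batches)


ds = IterableOnly([("batch0", "label0"), ("batch1", "label1")])
ds_train = iter(ds)        # wrap once; next(ds) itself would raise TypeError
batch = next(ds_train)[0]  # mirrors next(ds_train)[0] in the traceback
```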

Error running latent_ode.py

I tried running the script on the PhysioNet data and got the following error after a few iterations. Could you comment on this, and also say a bit more about what the expected output is?
TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

When I ran it again, it just hung here:

 python latent_ode.py --reg r3 --lam 1e-2
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
~/.conda/envs/Neural_ODE/lib/python3.8/site-packages/jax/_src/random.py:511: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation
  warnings.warn(msg, FutureWarning)

conda environment:

_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       1_gnu    conda-forge
absl-py                   0.13.0                    <pip>
backcall                  0.2.0                     <pip>
ca-certificates           2021.5.30            ha878542_0    conda-forge
certifi                   2021.5.30        py38h578d9bd_0    conda-forge
cycler                    0.10.0                    <pip>
Cython                    0.29.19                   <pip>
debugpy                   1.3.0                     <pip>
dm-haiku                  0.0.5.dev0                <pip>
flatbuffers               2.0                       <pip>
future                    0.18.2                    <pip>
ipykernel                 6.0.0                     <pip>
ipython                   7.25.0                    <pip>
ipython-genutils          0.2.0                     <pip>
jax                       0.2.17                    <pip>
jaxlib                    0.1.68                    <pip>
jedi                      0.18.0                    <pip>
jmp                       0.0.2                     <pip>
joblib                    0.15.1                    <pip>
jupyter-client            6.1.12                    <pip>
jupyter-core              4.7.1                     <pip>
kiwisolver                1.2.0                     <pip>
ld_impl_linux-64          2.36.1               hea4e1c9_0    conda-forge
libffi                    3.3                  h58526e2_2    conda-forge
libgcc-ng                 9.3.0               h2828fa1_19    conda-forge
libgomp                   9.3.0               h2828fa1_19    conda-forge
libstdcxx-ng              9.3.0               h6de172a_19    conda-forge
matplotlib                3.2.1                     <pip>
matplotlib-inline         0.1.2                     <pip>
ncurses                   6.2                  h58526e2_4    conda-forge
numpy                     1.21.0                    <pip>
openssl                   1.1.1k               h7f98852_0    conda-forge
opt-einsum                3.3.0                     <pip>
parso                     0.8.2                     <pip>
pexpect                   4.8.0                     <pip>
phate                     1.0.7                     <pip>
pickleshare               0.7.5                     <pip>
pip                       21.1.3             pyhd8ed1ab_0    conda-forge
POT                       0.7.0                     <pip>
prompt-toolkit            3.0.19                    <pip>
ptyprocess                0.7.0                     <pip>
Pygments                  2.9.0                     <pip>
pyparsing                 2.4.7                     <pip>
python                    3.8.10          h49503c6_1_cpython    conda-forge
python-dateutil           2.8.1                     <pip>
python_abi                3.8                      2_cp38    conda-forge
pyzmq                     22.1.0                    <pip>
readline                  8.1                  h46c0cb4_0    conda-forge
s-gd2                     1.8                       <pip>
scikit-learn              0.23.1                    <pip>
scipy                     1.4.1                     <pip>
setuptools                49.6.0           py38h578d9bd_3    conda-forge
six                       1.15.0                    <pip>
sklearn                   0.0                       <pip>
sqlite                    3.36.0               h9cd32fc_0    conda-forge
tabulate                  0.8.9                     <pip>
threadpoolctl             2.1.0                     <pip>
tk                        8.6.10               h21135ba_1    conda-forge
torch                     1.5.0                     <pip>
torchdiffeq               0.0.1                     <pip>
tornado                   6.1                       <pip>
traitlets                 5.0.5                     <pip>
wcwidth                   0.2.5                     <pip>
wheel                     0.36.2             pyhd3deb0d_0    conda-forge
xz                        5.2.5                h516909a_1    conda-forge
zlib                      1.2.11            h516909a_1010    conda-forge
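The first error comes from assigning into a JAX array in place. A minimal sketch of the functional-update idiom the message points toward, written with the x.at[...] syntax of current JAX releases (jax.ops.index_update and jax.ops.index_add have since been deprecated and removed), plus jax.random.permutation as the replacement named in the FutureWarning. The arrays are illustrative values, not data from latent_ode.py.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(5)
# x[2] = 1.0 would raise TypeError: JAX arrays are immutable.
x = x.at[2].set(1.0)  # returns a new array with index 2 replaced
x = x.at[0].add(3.0)  # returns a new array with 3.0 added at index 0

# jax.random.shuffle is deprecated; permutation shuffles along axis 0.
key = jax.random.PRNGKey(0)
perm = jax.random.permutation(key, jnp.arange(10))
```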
