Giter VIP home page Giter VIP logo

torchode's Introduction

A Parallel ODE Solver for PyTorch

pytest

torchode is a suite of single-step ODE solvers such as dopri5 or tsit5 that are compatible with PyTorch's JIT compiler and parallelized across a batch. JIT compilation often gives a performance boost, especially for code with many small operations such as an ODE solver, while batch-parallelization means that the solver can take a step of 0.1 for one sample and 0.33 for another, depending on each sample's difficulty. This can avoid performance traps for models of varying stiffness and ensures that the model's predictions are independent from the compisition of the batch. See the paper for details.

If you get stuck at some point, you think the library should have an example on x or you want to suggest some other type of improvement, please open an issue on github.

Installation

You can get the latest released version from PyPI with

pip install torchode

To install a development version, clone the repository and install in editable mode:

git clone https://github.com/martenlienen/torchode
cd torchode
pip install -e .

Usage

import matplotlib.pyplot as pp
import torch
import torchode as to

def f(t, y):
    return -0.5 * y

y0 = torch.tensor([[1.2], [5.0]])
n_steps = 10
t_eval = torch.stack((torch.linspace(0, 5, n_steps), torch.linspace(3, 4, n_steps)))

term = to.ODETerm(f)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
solver = to.AutoDiffAdjoint(step_method, step_size_controller)
jit_solver = torch.compile(solver)

sol = jit_solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval))
print(sol.stats)
# => {'n_f_evals': tensor([26, 26]), 'n_steps': tensor([4, 2]),
# =>  'n_accepted': tensor([4, 2]), 'n_initialized': tensor([10, 10])}

pp.plot(sol.ts[0], sol.ys[0])
pp.plot(sol.ts[1], sol.ys[1])

Citation

If you build upon this work, please cite the following paper.

@inproceedings{lienen2022torchode,
  title = {torchode: A Parallel {ODE} Solver for PyTorch},
  author = {Marten Lienen and Stephan G{\"u}nnemann},
  booktitle = {The Symbiosis of Deep Learning and Differential Equations II, NeurIPS},
  year = {2022},
  url = {https://openreview.net/forum?id=uiKVKTiUYB0}
}

torchode's People

Contributors

martenlienen avatar slishak avatar dependabot[bot] avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.