
Comments (5)

patrick-kidger commented on June 4, 2024

Can you provide a MWE?

hkortier commented on June 4, 2024

Sorry, very late, but still relevant. Here is a single-shooting example:

import jax
import jax.numpy as jnp

import equinox as eqx  
import optimistix as optx
import optax
import diffrax as dfx  

from watermark import watermark

jax.config.update("jax_enable_x64", True)

def c_func(mach):
    return jnp.select([mach < 0.4, mach < 0.8, mach < 1.2], 
                      [0.1, .1 * (mach - 0.4) / 0.4 + 0.1, 0.25 * (mach - 0.8) / 0.4 + 0.25], default=.5)

class CannonODE(eqx.Module):
    c: float 
    g: float 

    def __call__(self, t, y, args):    
        v = y[1]
        T, = args
        speed = jnp.linalg.norm(v)

        mach = speed / 340.0
        c = c_func(mach)  
    
        dp = T * v
        dv = T * jnp.array([-c * v[0] * speed,
                        -c * v[1] * speed - self.g])

        return (dp, dv)
    
class CannonTrajectory(eqx.Module):
    ode: CannonODE

    def __init__(self, ode):
        self.ode = ode

    def __call__(self, parameter, saveat: dfx.SaveAt):
        QE, v0, T = parameter
        y0 = (jnp.array([0.0, 0.0]) , jnp.array([v0*jnp.cos(QE), v0*jnp.sin(QE)]))

        term = dfx.ODETerm(self.ode)
        stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
        solver = dfx.Tsit5()
        t0 = saveat.subs.ts[0]
        t1 = saveat.subs.ts[-1]
        dt0 = 0.01

        sol = dfx.diffeqsolve(
            term,
            solver,
            t0,
            t1,
            dt0,
            y0,
            args=(T,),
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            # support forward-mode autodiff, which is used by Levenberg--Marquardt
            adjoint=dfx.DirectAdjoint(),
            max_steps=1024,
        )
        return sol

def residuals(parameter, args):
    traj, target = args
    saveat = dfx.SaveAt(ts=jnp.array([0., 1.]))
    pred_values = traj(parameter, saveat).ys[0][-1,:]
    return target - pred_values

def residuals_min(parameter, args):
    res = residuals(parameter, args)
    return jnp.sqrt(jnp.dot(res, res))

def main(target):
    v0 = 200.0
    QE0 = 0.01  # jnp.pi/4
    T0 = 2.0

    ode = CannonODE(c=0.6, g=9.81)
    traj = CannonTrajectory(ode)

    init_parameter = jnp.array([QE0, v0, T0])

    solver = optx.OptaxMinimiser(optax.adabelief, rtol=1e-8, atol=1e-8)
    res = optx.minimise(residuals_min, solver, init_parameter, max_steps=128, throw=False, args=(traj, target))
    
    return res, traj, target

if __name__ == "__main__":
    print(watermark(packages="jax,jaxlib,optimistix,equinox,diffrax,optax"))
    target = jnp.array([100., 0.])
    res, traj, target = main(target)

output:

jax       : 0.4.20
jaxlib    : 0.4.14
optimistix: 0.0.5
equinox   : 0.11.2
diffrax   : 0.4.1
optax     : 0.1.7

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
.....
  File "/Users/hkortier/venvs/diffrax/lib/python3.10/site-packages/optimistix/_solver/optax.py", line 90, in init
    opt_state = self.optim.init(y)
AttributeError: '_Closure' object has no attribute 'init'

patrick-kidger commented on June 4, 2024

Ah! You want optax.adabelief(...), not just optax.adabelief.

hkortier commented on June 4, 2024

Ah, thanks for your prompt response! I took this line from https://docs.kidger.site/optimistix/how-to-choose/:
optimistix.OptaxMinimiser(optax.adabelief, learning_rate=1e-3, rtol=1e-8, atol=1e-8)
However, further down that page the correct syntax is given.

patrick-kidger commented on June 4, 2024

Ah, thank you for pointing out the mistake! This should now be fixed in #29, so I'm closing this.
