Comments (11)

johannahaffner commented on July 19, 2024

Related: I opened google/jax#21581

from optimistix.

johannahaffner commented on July 19, 2024

Hi Jason,

haha, I had the same thing on my to-do list and just wrote an MWE. I came across this a few weeks ago; it would be really handy to be able to vmap over initial conditions and then check the stats, for example in a multi-start scenario. (Not super urgent for me.) My workaround so far has been to return only solution.value, optionally together with solution.stats, and to ignore the rest.
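A minimal sketch of that "return only arrays" workaround in a multi-start setting, written in plain JAX rather than Optimistix. `newton_sqrt` is a hypothetical toy solver (Newton's method for y**2 = 2) that returns just a value and a stats dict of arrays, so an ordinary `jax.vmap` over the initial guesses works.

```python
import jax
import jax.numpy as jnp


def newton_sqrt(y0, target=2.0, tol=1e-6, max_steps=50):
    """Hypothetical toy solver: Newton's method for y**2 = target.

    Returns only arrays (the solution value and a stats dict), mirroring the
    "return solution.value, optionally solution.stats" workaround.
    """

    def cond(carry):
        y, step = carry
        return (jnp.abs(y * y - target) > tol) & (step < max_steps)

    def body(carry):
        y, step = carry
        return y - (y * y - target) / (2 * y), step + 1

    value, num_steps = jax.lax.while_loop(cond, body, (y0, 0))
    return value, {"num_steps": num_steps}


# Multi-start: one solve per initial guess, batched with a plain jax.vmap.
y0s = jnp.array([0.5, 1.0, 3.0, 10.0])
values, stats = jax.vmap(newton_sqrt)(y0s)
```

Every output leaf is an array, so nothing non-JAX-typed ever reaches the vmap boundary.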

Why does the solution object need to include the whole jaxpr in the first place? It certainly seems useful to be able to inspect it during debugging, but otherwise I have no need to look at it. Maybe it could be made optional? (I optimize over the parameters of ODE models, so my jaxprs are always very long.)

In this case, the jaxpr contains the following useful message:

nonbatchable[
    allow_constant_across_batch=True
    msg=Nonconstant batch. `equinox.internal.while_loop` has received a batch of values that 
    were expected to be constant. This is probably an internal error in the library you are using.
]

However, the code does not fail with that message. Instead, it fails twice, once with

.../site-packages/jax/_src/interpreters/batching.py:1107, in matchaxis(axis_name, sz, src, dst, x, sum_match)
   1105   _ = core.get_aval(x)
   1106 except TypeError as e:
-> 1107   raise TypeError(f"Output from batched function {x!r} with type "
   1108                   f"{type(x)} is not a valid JAX type") from e

And once with

.../site-packages/jax/_src/core.py:1455, in concrete_aval(x)
   1454   return concrete_aval(x.__jax_array__())
-> 1455 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
   1456                  "type")

For what it's worth, here is the MWE, even if it is now probably redundant :) I added BFGS as a solver.

Edit: print the number of lines in the jaxpr (1.2k lines, more than 100k characters).
Edit Nr. 2: add forward- and backward options.
Edit Nr. 3: condense MWE. (Also removed line count for jaxpr.)

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from jaxtyping import Array

import equinox as eqx
import diffrax as dfx
import optimistix as optx

import pytest

class ToyModel(eqx.Module):
    _term: dfx.ODETerm

    def __init__(self):
        def dydt(t, y, k):  # Monoexponential decay
            return - k * y
        self._term = dfx.ODETerm(dydt)

    def __call__(self, param):
        t0 = 0.
        t1 = 10.
        dt0 = 0.01
        y0 = jnp.array([10.])
        
        sol = dfx.diffeqsolve(
            self._term, 
            dfx.Tsit5(), 
            t0, t1, dt0, y0, args=param,
            saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1)),
            adjoint=dfx.DirectAdjoint(),  # Supports both fwd and bwd autodiff
        )
        return sol.ys

def estimate_parameters(initial_guess, model, data, solver, solver_options: dict = dict(jac="fwd")):
    """Function that estimates the parameters."""

    def residuals(param, args):
        model, data = args
        fit = model(param)
        res = data - fit
        return res

    sol = optx.least_squares(
        residuals, 
        solver, 
        initial_guess,
        args = (model, data),
        options=solver_options,
    )
    return sol

model = ToyModel()
k = jnp.array([0.5])  # True value
ode_solution = model(k)

k0s = jnp.transpose(jnp.array([jnp.arange(0.0, 0.45, 0.05)]))

bfgs = optx.BFGS(atol=1e-09, rtol=1e-06)
lm = optx.LevenbergMarquardt(atol=1e-09, rtol=1e-06)
gn = optx.GaussNewton(atol=1e-09, rtol=1e-06)

# Add a vmap on top
vmapped_fwd_solve = jax.vmap(estimate_parameters, in_axes=(0, None, None, None))
vmapped_bwd_solve = jax.vmap(jtu.Partial(estimate_parameters, solver_options=dict(jac="bwd")), in_axes=(0, None, None, None))

for solver in [bfgs, lm, gn]:
    sol_bwd = vmapped_bwd_solve(k0s, model, ode_solution, solver)
    assert isinstance(sol_bwd, optx.Solution)

    if solver == bfgs:
        sol_fwd = vmapped_fwd_solve(k0s, model, ode_solution, solver)
        assert isinstance(sol_fwd, optx.Solution)
    else:
        with pytest.raises(TypeError):
            sol_fwd = vmapped_fwd_solve(k0s, model, ode_solution, solver)
print("Checks passed, expected errors raised. This is what happens: ")

# Repeat one of the calls that raises the error
vmapped_fwd_solve(k0s, model, ode_solution, lm)  # This fails

johannahaffner commented on July 19, 2024

I realise this is a bit long for an MWE - only the last line fails.

johannahaffner commented on July 19, 2024

I dug a little into _make_f_info from the gauss_newton module, and it seems to me that the lambda function in jax.linearize(...) is redundant. If you print the returned linearized function lin_fn, you can see that it contains a jaxpr, which the output of a plain jax.linearize(...) call does not.

What I do not understand is why there has to be an auxiliary output, but I have worked around it with a wrapper for now:

import pytest

import jax
import jax.numpy as jnp
import jax.tree_util as jtu  # needed by tree_allclose below
from jax.core import Jaxpr

import equinox as eqx
from optimistix._solver.gauss_newton import _make_f_info

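def _for_jacrev(_y):
    """Copied from: optimistix._solver.gauss_newton (for reference only, not called below;
    `fn` and `args` come from the enclosing scope in the original)."""
    f_eval, aux_eval = fn(_y, args)  # Why does this assume an auxiliary output?
    return f_eval, (f_eval, aux_eval)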

def _no_nan(x):
    """Copied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x

def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)


def residuals(origin, args):
    del args
    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0)**2
    true = shifted_parabola(2)
    fit = shifted_parabola(origin)
    return true - fit

def aux_wrapper(origin, args):
    return residuals(origin, args), None

initial_guess = 1.
args = ()
jac_bwd = jax.jacrev(residuals)(initial_guess, args)
jac_fwd = jax.jacfwd(residuals)(initial_guess, args)
assert tree_allclose(jac_bwd, jac_fwd)

with pytest.raises(ValueError):  # Only works with aux wrapper
    _make_f_info(residuals, initial_guess, args, set(), "bwd")

# Now with auxiliary output wrapper
(residual_jac_bwd, _) = _make_f_info(aux_wrapper, initial_guess, args, set(), "bwd")
assert tree_allclose(residual_jac_bwd.jac.pytree, jac_bwd)
(residual_jac_fwd, _) = _make_f_info(aux_wrapper, initial_guess, args, set(), "fwd")
assert isinstance(residual_jac_fwd.jac.fn.jaxpr, Jaxpr)

# The following snippet (jax.linearize...) is copied from _make_f_info (line 174).
# Without has_aux=True, jax.linearize returns two values, so unpacking into three names
# raises ValueError. One `with` block per call, so that every call is actually executed.
with pytest.raises(ValueError):  # Again, must have aux
    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: aux_wrapper(_y, args), initial_guess, has_aux=False)
with pytest.raises(ValueError):
    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: aux_wrapper(_y, args), initial_guess)
with pytest.raises(ValueError):
    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: residuals(_y, args), initial_guess)

f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: aux_wrapper(_y, args), initial_guess, has_aux=True)
assert aux_eval is None
assert tree_allclose(f_eval, residuals(initial_guess, args))
print(lin_fn)  # lin_fn has jaxpr

# Compare to jax.linearize without the lambda function
res, residuals_jvp = jax.linearize(residuals, *(initial_guess, args))
assert tree_allclose(res, residuals(initial_guess, args))
assert tree_allclose(jac_bwd, residuals_jvp(initial_guess, args))

res, residuals_jvp, aux = jax.linearize(aux_wrapper, *(initial_guess, args), has_aux=True)
assert tree_allclose(res, residuals(initial_guess, args))
assert tree_allclose(jac_bwd, residuals_jvp(initial_guess, args))
assert aux is None
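On the question of why an auxiliary output is assumed: one plausible reason (my reading, not confirmed against the Optimistix source) is the `_for_jacrev` pattern copied above. `jax.jacrev(..., has_aux=True)` returns the Jacobian together with the auxiliary output, so putting `f_eval` into the aux gives both the Jacobian and the residuals from a single reverse-mode pass. The `residuals` below is the shifted parabola from the snippet, rewritten to return `(residuals, None)`.

```python
import jax
import jax.numpy as jnp


def residuals(y, args):
    """Shifted-parabola residuals, returning (output, aux) with aux=None."""
    del args
    x = jnp.linspace(0.0, 10.0)
    return (x - 2.0) ** 2 - (x - y) ** 2, None


def _for_jacrev(y, args):
    f_eval, aux_eval = residuals(y, args)
    # Differentiate f_eval, and also carry f_eval out through the aux channel.
    return f_eval, (f_eval, aux_eval)


jac, (f_eval, aux_eval) = jax.jacrev(_for_jacrev, has_aux=True)(1.0, ())
```

Analytically, d/dy of -(x - y)**2 is 2(x - y), so at y = 1 the Jacobian should be 2(x - 1).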

johannahaffner commented on July 19, 2024

By now you can skip most of my thought process above :) I believe it is a somewhat subtle issue involving the eval shapes passed to the linear operator in _make_f_info from gauss_newton.py. The JVPs computed using jax.linearize contain jaxprs in both cases, but one evaluates to the correct Jacobian and one does not.

Using the MWE below, I get

TypeError: Expected PyTreeDef((*, ())), got PyTreeDef(((*, ()),))

MWE

import pytest

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

import equinox as eqx
import lineax as lx
from optimistix._solver.gauss_newton import _make_f_info

def _no_nan(x):
    """Copied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)

def fn(y, args):  # Optimistix insists on aux, it seems: return an extra 0 as aux
    del args
    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0)**2
    true = shifted_parabola(2)
    fit = shifted_parabola(y)
    return true - fit, 0  # Return zero as aux

y0 = 1.
nothing = ()

# Compute jacobian two ways, ignore fn_eval, aux_eval
_, fn_j, _ = jax.linearize(lambda _y: fn(_y, nothing), y0, has_aux=True)  # Status quo
_, fn_j_no_lambda, _ = jax.linearize(fn, *(y0, nothing), has_aux=True)

# Now check the Jacobians
true_jac_eval, _ = jax.jacfwd(fn)(y0, nothing)  # Throw away aux_eval
with pytest.raises(TypeError):
    assert tree_allclose(fn_j(y0, nothing), true_jac_eval)  # lambda used: input is unexpected pytree 
assert tree_allclose(fn_j_no_lambda(y0, nothing), true_jac_eval)

# Check eval shapes
fn_eval_shape, aux_eval_shape = jax.eval_shape(fn, *(y0, nothing))  # Returns tuple that includes aux
fn_j_no_lambda_eval_shape = jax.eval_shape(fn_j_no_lambda, *(y0, nothing))
assert tree_allclose(fn_eval_shape, fn_j_no_lambda_eval_shape)
with pytest.raises(TypeError):
    jax.eval_shape(fn_j, *(y0, nothing))  # unexpected pytree, again
   
# Create lx.FunctionLinearOperator
lx.FunctionLinearOperator(fn_j_no_lambda, jax.eval_shape(lambda x: x, (y0, nothing)))
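The `PyTreeDef` mismatch reported above can be reproduced in plain JAX, without lineax or Optimistix. The function returned by `jax.linearize` expects one tangent per primal, so closing over `args` with a lambda versus passing `args` as a primal changes its calling convention. This is a sketch of my reading of the error, not a statement about Optimistix internals.

```python
import jax
import jax.numpy as jnp


def fn(y, args):
    del args  # unused, as in the MWE above
    x = jnp.linspace(0.0, 10.0)
    return (x - 2.0) ** 2 - (x - y) ** 2


y0, nothing = 1.0, ()

# Closing over `nothing`: the linearized function expects ONE tangent (for y).
_, jvp_closed = jax.linearize(lambda _y: fn(_y, nothing), y0)

# Passing `nothing` as a primal: the linearized function expects one tangent per primal.
_, jvp_both = jax.linearize(fn, y0, nothing)

tangent = 1.0
closed_out = jvp_closed(tangent)
both_out = jvp_both(tangent, ())

# Calling the closed-over version with two tangents is a pytree-structure mismatch
# (a TypeError in the JAX versions I know of).
try:
    jvp_closed(tangent, ())
    mismatch_raised = False
except (TypeError, ValueError):
    mismatch_raised = True
```

Both versions compute the same directional derivative; they only differ in how many tangent arguments they accept.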

johannahaffner commented on July 19, 2024

Here is my latest iteration, still poking at the two lines in _make_f_info from Gauss-Newton.
I was able to show that the jaxpr is not causing the issue, at least not outside of FunctionLinearOperator: vmapping over a Jacobian that contains a jaxpr and was produced by jax.linearize raises no errors.

import pytest

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

import equinox as eqx
import lineax as lx

def _no_nan(x):
    """Copied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x
def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)

def fn(y):
    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0)**2
    true = shifted_parabola(2.)  # True value
    fit = shifted_parabola(y)
    return true - fit

def aux_wrapper(y):
    return fn(y), None

y0 = 1.  # starting guess
y0s = jnp.arange(0., 4., 0.1)  # Many initial values

# Get jacobians the simple way
_, jac_of_fn = jax.linearize(fn, y0)
_, jac_of_aux_wrapper, _ = jax.linearize(aux_wrapper, y0, has_aux=True)
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_fn(y0)) 
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_aux_wrapper(y0))  

vmapped_jac_of_fn = jax.vmap(jac_of_fn)(y0s) # Does not raise error
vmapped_jac_of_aux_wrapper = jax.vmap(jac_of_aux_wrapper)(y0s) # Does not raise error

# Now include the lambda function in jax.linearize (status quo in optimistix)   
_, jac_of_fn_with_lambda = jax.linearize(lambda _y: fn(_y), y0)
_, jac_of_aux_wrapper_with_lambda, _ = jax.linearize(lambda _y: aux_wrapper(_y), y0, has_aux=True)

vmapped_jac_of_fn_with_lambda = jax.vmap(jac_of_fn_with_lambda)(y0s) # Does not raise error
vmapped_jac_of_aux_wrapper_with_lambda = jax.vmap(jac_of_aux_wrapper_with_lambda)(y0s) # Does not raise error

# Context: using lambda functions produces subtle difference in pytrees, not legible when examining pytreedef (as a human)
with pytest.raises(AssertionError):
    assert jtu.tree_structure(jac_of_fn) == jtu.tree_structure(jac_of_fn_with_lambda)
assert str(jtu.tree_structure(jac_of_fn)) == str(jtu.tree_structure(jac_of_fn_with_lambda))

# Create a lineax Linear Operator
def lin_fun(y):
    return 2 * y
lin_op = lx.FunctionLinearOperator(lin_fun, jax.eval_shape(lin_fun, y0))  # Confirm that it works in this case

with pytest.raises(ValueError): # I don't understand why it does not work in these cases
    # Note: only the first call below is executed; the block exits as soon as it raises.
    lx.FunctionLinearOperator(jac_of_fn, jax.eval_shape(jac_of_fn, y0))
    lx.FunctionLinearOperator(jac_of_fn_with_lambda, jax.eval_shape(jac_of_fn_with_lambda, y0))
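One possible explanation for this `ValueError` (hedged; not checked against lineax itself): `lx.FunctionLinearOperator(fn, input_structure)` takes the structure of the operator's *input*, whereas `jax.eval_shape(jac_of_fn, y0)` is the structure of its *output* (50 values, against a scalar input). `lin_fun(y) = 2 * y` preserves shapes, which would explain why that case works. The plain-JAX sketch below shows that the linearized function rejects a tangent shaped like its output.

```python
import jax
import jax.numpy as jnp


def fn(y):
    x = jnp.linspace(0.0, 10.0)
    return (x - 2.0) ** 2 - (x - y) ** 2


y0 = 1.0
_, jac_of_fn = jax.linearize(fn, y0)

input_structure = jax.eval_shape(lambda y: y, y0)  # shape (): what the operator consumes
output_structure = jax.eval_shape(jac_of_fn, y0)   # shape (50,): what it produces

# Feeding an output-shaped tangent back in is rejected by the linearized function.
try:
    jac_of_fn(jnp.zeros(output_structure.shape, output_structure.dtype))
    wrong_shape_raised = False
except (ValueError, TypeError):
    wrong_shape_raised = True
```

If this reading is right, passing `input_structure` (as in the `fn_j_no_lambda` example earlier) rather than the output structure is what the operator needs.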

tjltjl commented on July 19, 2024

A simple workaround: dataclasses.replace() the offending member of the output
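A sketch of that workaround, assuming the offending member is known. `ToySolution` is a hypothetical frozen dataclass standing in for the real solution object (registered as a pytree by hand), with a jaxpr in its `state` field; replacing `state` with `None` before returning leaves only arrays for `jax.vmap`.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToySolution:
    """Hypothetical stand-in for a solver's solution object."""

    value: Any
    state: Any


# Register the dataclass as a pytree so jax.vmap can return it.
jax.tree_util.register_pytree_node(
    ToySolution,
    lambda s: ((s.value, s.state), None),
    lambda _, children: ToySolution(*children),
)


def solve(y0):
    # `state` holds a jaxpr, which is not a valid JAX output type.
    sol = ToySolution(value=y0**2, state=jax.make_jaxpr(jnp.sin)(1.0))
    # Workaround: replace the offending member before returning.
    return dataclasses.replace(sol, state=None)


out = jax.vmap(solve)(jnp.arange(3.0))
```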

johannahaffner commented on July 19, 2024

The tricky thing is that I can't figure out what the offending member is.

johannahaffner commented on July 19, 2024

And it has now been shown that this is a deeper issue in jax.linearize, which produces pytrees with non-identical structure even for identical input functions called with identical inputs.
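A plain-JAX illustration of how behaviourally identical functions can still give non-identical tree structures. `jax.tree_util.Partial` treats the wrapped function as static treedef metadata, and function objects compare by identity, so two separately created but identical functions yield unequal structures. That `jax.linearize` returns such a `Partial` around a freshly created function on each call is an assumption suggested by the `tree_structure` comparison earlier in the thread, not something verified here.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def make_scale():
    # Each call creates a NEW function object with identical behaviour.
    return lambda c, t: c * t


c = jnp.array(2.0)
p1 = jtu.Partial(make_scale(), c)
p2 = jtu.Partial(make_scale(), c)
p3 = jtu.Partial(p1.func, c)  # reuses the same function object as p1

same_func = jtu.tree_structure(p1) == jtu.tree_structure(p3)
different_func = jtu.tree_structure(p1) == jtu.tree_structure(p2)
results_agree = bool(p1(3.0) == p2(3.0))
```

So the structures can differ even though the functions compute the same thing, which is consistent with the `str(...)` of the two structures looking alike to a human reader.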

patrick-kidger commented on July 19, 2024

I think I understand what's going on here. The output of optx.least_squares includes a jaxpr inside of out.state. This isn't an array-like object, so JAX doesn't understand how to handle it as an output of the vmap. Morally speaking, what's going on here is the same as jax.vmap(lambda x: object())(...), in which, again, a non-array-like object is being returned.
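A minimal plain-JAX reproduction of that situation: a vmapped function whose output contains a jaxpr (made here with `jax.make_jaxpr`, standing in for the jaxpr inside `out.state`) fails with a `TypeError` about an invalid JAX type, while returning only the array part works.

```python
import jax
import jax.numpy as jnp


def solve(y0):
    value = y0**2
    jaxpr = jax.make_jaxpr(jnp.sin)(1.0)  # not an array: a ClosedJaxpr object
    return value, jaxpr


y0s = jnp.arange(3.0)

try:
    jax.vmap(solve)(y0s)
    error_type = None
except TypeError as e:
    error_type = type(e)

# Returning only the array part works as usual.
values = jax.vmap(lambda y0: solve(y0)[0])(y0s)
```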

The solution is pretty simple: use eqx.filter_vmap instead. This passes through all non-array-like objects unchanged. Indeed, the use case in this issue is the raison d'être of eqx.filter_vmap!
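To illustrate the idea behind that pass-through (this is not Equinox's actual implementation), here is a rough pure-JAX sketch. `simple_filter_vmap` is a hypothetical helper: it vmaps only the array leaves of the output and hands non-array leaves back unchanged, assuming they do not depend on the batch. In practice, use `eqx.filter_vmap`.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def simple_filter_vmap(fn):
    """Hypothetical helper: vmap `fn`, passing non-array outputs through unchanged."""

    def wrapped(*args):
        static = {}

        def inner(*a):
            leaves, treedef = jtu.tree_flatten(fn(*a))
            # Tracers count as jax.Array, so batched outputs land in `arrays`.
            arrays = [x if isinstance(x, jax.Array) else None for x in leaves]
            static["others"] = [None if isinstance(x, jax.Array) else x for x in leaves]
            static["treedef"] = treedef
            return arrays

        arrays = jax.vmap(inner)(*args)
        leaves = [a if o is None else o for a, o in zip(arrays, static["others"])]
        return jtu.tree_unflatten(static["treedef"], leaves)

    return wrapped


def solve(y0):
    return y0**2, jax.make_jaxpr(jnp.sin)(1.0)  # array output plus a jaxpr


values, jaxpr = simple_filter_vmap(solve)(jnp.arange(3.0))
```

Here `values` is batched as usual, while `jaxpr` comes back as a single, unbatched object.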

Does this solve the issues everyone is facing?

johannahaffner commented on July 19, 2024

Oh dear :D

It does solve my issue. I was actually in the process of replacing all vmaps with filter_vmaps, but there were still some around. Not anymore, though!