
optimistix's Introduction

JAX scientific ML ecosystem:

Probably the reason you're here. I would highlight:

  1. Equinox: elegant neural networks.

  2. Diffrax: numerical ODE/SDE solvers.

  3. jaxtyping: shape/dtype annotations for arrays.

  4. Lineax: linear/least-squares solvers. (new!)

  5. Optimistix: root finding, least squares, etc. (new!)

Other links:

Me:

I currently wear multiple hats across bio/ML/CS at Cradle.bio. These days I am generally interested in scientific ML, and specifically the application of ML to unsolved problems in biology!

I also hold an honorary lectureship at Imperial College London. In past lives I wore the same multitude of hats at Google X, and did my PhD at the University of Oxford.

optimistix's People

Contributors

colcarroll, packquickly, patrick-kidger


optimistix's Issues

Qs: Mapping different solvers to leaves, parameter normalisation, different parameter scales

Hey, first up: another awesome package in the JAX ecosystem! I've been meaning to incorporate these kinds of solvers into my work for a long time, so thanks for making it easy 😛. This is partly a discussion post, as I am relatively unfamiliar with these algorithms. I have done my best to parse the docs in detail, but feel free to point me to any external resources, as I would love to learn more.


Mapping solvers to different leaves

Is there a way to map different solvers to each leaf of a pytree? Let's say we know one parameter will be initialised in the smooth bowl of the loss landscape and can be solved with BFGS, but another parameter has a 'noisy' loss landscape and is best tackled with a regular Optax optimiser. This is actually quite typical of the optical models I work with, although not super common in general AFAIK.

It is simple to apply each of these algorithms to its own pytree leaf, one at a time, using eqx.partition and eqx.combine. This works, but it can't optimise the leaves 'jointly', and it computes the model gradients redundantly, since the gradients from a single evaluation could be passed to both algorithms.

Now, I recognise that a 'joint' approach would pose a problem for algorithms like BFGS, since they would be trying to find the minimum of a loss landscape that changes as the other leaves are updated throughout the optimisation. I'd be curious what you think the right approach to this kind of problem is. Are there solvers designed for it? If not, what approach would you take? I'm very excited about the flexibility and extensibility of this software, and about being able to build much better custom solvers for my rather niche set of optimisation problems.
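A minimal sketch in plain JAX of the "shared gradient" part of the question: one value_and_grad call per iteration, with a different update rule applied to each leaf. This is not an Optimistix API; the model, the leaf names "smooth"/"noisy", and both update rules are hypothetical stand-ins. It sidesteps, rather than solves, the changing-landscape concern for BFGS-style methods.

```python
import jax
import jax.numpy as jnp


def loss(params):
    # Hypothetical model: a smooth bowl in "smooth" plus a kinked |.| term in "noisy".
    return jnp.sum((params["smooth"] - 2.0) ** 2) + jnp.sum(jnp.abs(params["noisy"] - 1.0))


# Hypothetical per-leaf update rules (plain functions, not Optimistix solvers).
def smooth_update(p, g):
    return p - 0.25 * g  # plain gradient step


def noisy_update(p, g):
    return p - 0.01 * jnp.sign(g)  # small sign-of-gradient step


UPDATES = {"smooth": smooth_update, "noisy": noisy_update}


@jax.jit
def joint_step(params):
    # One forward + backward pass; the gradient pytree is shared by both rules.
    value, grads = jax.value_and_grad(loss)(params)
    new_params = {name: UPDATES[name](params[name], grads[name]) for name in params}
    return new_params, value


params = {"smooth": jnp.array([0.0, 4.0]), "noisy": jnp.array([0.5])}
for _ in range(100):
    params, value = joint_step(params)
```

Because both leaves are updated from the same gradient pytree, the leaves are optimised jointly and nothing is recomputed; swapping in a quasi-Newton rule for one leaf would need its own state carried alongside params.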


Parameter normalisation during the solve loop

So during a gradient descent loop we commonly need to apply some normalisation/regularisation to our updated parameters to ensure they are 'physical'. An example would be normalising relative spectral weights to have a sum of 1 after the updates have been applied. I am wondering if there is a way to enforce these constraints during the solve. The simplest example case here would be preventing some values from being above some threshold.

I would guess this would likely be possible through a user-defined solver class, that applies the custom regularisation. If something like this is possible, how would it be implemented? From a crude look at the code it looks like this could be done within the step function of the AbstractGradientDescent class?
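One common way to enforce such constraints is projected gradient descent: take an ordinary gradient step, then map the parameters back onto the feasible set. A minimal plain-JAX sketch, assuming hypothetical leaves "weights" (must be positive and sum to 1) and "amplitudes" (capped at a threshold). Whether this belongs in a custom solver's step in Optimistix is the question being asked; the sketch below is a standalone loop and makes no claim about Optimistix internals.

```python
import jax
import jax.numpy as jnp

THRESHOLD = 1.0  # hypothetical upper bound on the "amplitudes" leaf
TARGET_WEIGHTS = jnp.array([0.2, 0.3, 0.5])  # hypothetical data


def loss(params):
    return (jnp.sum((params["weights"] - TARGET_WEIGHTS) ** 2)
            + jnp.sum((params["amplitudes"] - 3.0) ** 2))


def project(params):
    # Keep weights positive and renormalise them to sum to 1 (a simple
    # normalisation, not the exact Euclidean projection onto the simplex).
    w = jnp.maximum(params["weights"], 1e-12)
    # Cap amplitudes at the threshold.
    a = jnp.minimum(params["amplitudes"], THRESHOLD)
    return {"weights": w / jnp.sum(w), "amplitudes": a}


@jax.jit
def projected_step(params):
    grads = jax.grad(loss)(params)
    updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return project(updated)  # enforce the constraints after every update


params = {"weights": jnp.full(3, 1.0 / 3.0), "amplitudes": jnp.array([0.0, 5.0])}
for _ in range(200):
    params = projected_step(params)
```

The unconstrained optimum for the amplitudes is 3.0, so the projection holds them at the threshold of 1.0, while the weights converge to the (already normalised) targets.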


Parameters with large scale variation

So this one is more of an open discussion, rather than a specific question. It's very common for the models I work with to have vastly different scales (everything from 1e10 to 1e-10). This is a problem for these algorithms in general, so I was hoping to get your thoughts on what would be the right way to approach a solution.

There is the 'naive' solution where you apply a scaling to each parameter of the pytree before passing it into the minimisation function, and then inverting the scaling once inside the function. Now this works but is far from what I would consider ideal as it still requires a degree of prior knowledge of the model and sort of just kicks the tunable hyper-parameter from a learning rate into a scaling. Granted this is still going to be generally more robust, but I feel like there is something more elegant... I'm wondering if you have any thoughts or ideas about this!
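For concreteness, the "naive" rescaling described above can be written as a reparameterisation: the optimiser works on O(1) variables z, and the model sees z * scale. A minimal plain-JAX sketch; the per-leaf scales and the toy loss are hypothetical, and gradient descent stands in for whichever solver is actually used.

```python
import jax
import jax.numpy as jnp

# Hypothetical characteristic scales, one per leaf (the prior knowledge the text mentions).
scales = {"big": jnp.array(1e10), "small": jnp.array(1e-10)}


def physical_loss(params):
    # Toy model whose optimum sits at wildly different magnitudes.
    return (params["big"] / 3e10 - 1.0) ** 2 + (params["small"] / 2e-10 - 1.0) ** 2


def to_physical(z):
    return jax.tree_util.tree_map(lambda zi, s: zi * s, z, scales)


def to_scaled(params):
    return jax.tree_util.tree_map(lambda p, s: p / s, params, scales)


def scaled_loss(z):
    # The optimiser only ever sees O(1) numbers.
    return physical_loss(to_physical(z))


grad_fn = jax.jit(jax.grad(scaled_loss))
z = to_scaled({"big": jnp.array(1e10), "small": jnp.array(1e-10)})
for _ in range(500):
    z = jax.tree_util.tree_map(lambda zi, gi: zi - 0.5 * gi, z, grad_fn(z))
params = to_physical(z)
```

A single learning rate now works for both leaves, at the cost of choosing the scales up front, which is exactly the trade-off the paragraph above describes.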


Anyway thanks again for the excellent software and the help!

correct name of the exception class that Equinox uses for runtime errors

Hi,

I'm using a couple of Equinox pytrees in my program, and in one case one is used in conjunction with Newton root finding from Optimistix. My larger code is a gradient descent variation, and occasionally a data point is expected to have no solution to the root-finding problem. When setting up a try/except block, what is the correct name of the exception class that Equinox uses for runtime errors, or should it be something from Optimistix?

instance_of_acceleration = AccelerationPytree(l_pr, regime, kinetic_conservative, rot_dissapative, ld_dissapative, epd_dissapative_1, qe_conservative_1, epd_dissapative_2, epd_dissapative_3, epd_dissapative_4, qe_conservative_2, qe_conservative_3)

solver_root = optx.Newton(rtol=1e-8, atol=1e-8)
y0 = jnp.array(0.1)
try:
    sol = optx.root_find(fn=time_root_from_distance, solver=solver_root, y0=y0, args=instance_of_acceleration, options=dict(lower=0.), max_steps=20000, throw=False)
    Thv = sol.value
except eqx.exception_module.EqxRuntimeError:  # (WHAT GOES HERE?)
    # Set Thv to a default value or handle it accordingly
    Thv = 999.  # Replace with an appropriate default value or action
print(Thv)
return Thv

Please and thanks,
Tom

Efficient NewtonCG Implementation

Hi all, thanks for the phenomenal library. We're already using it in several statistical genetics methods in my group!

I've been porting over some older code of mine to use optimistix, rather than hand-rolled inference procedures and could use some advice. Currently, I am performing some variational inference using a mix of closed-form updates for variational parameters, as well as gradient-based updates for some hyperparameters. It -roughly- works like,

while True:
  eval_f = jax.value_and_grad(_infer, has_aux=True)
  ((value, var_params), gradient) = eval_f(hyper_param, var_params, data)
  hyper_param = hyper_param + learning_rate * gradient
  if converged:
    break

I'd -like- to retool the above so that it not only reports the current value, the aux values (i.e. the updated variational parameters), and the gradient with respect to the hyperparameters, but also returns an -hvp- function that could be used in a Newton-CG-like step in Optimistix. I know of the new minimise function, but what isn't clear is how to set things up so that it reports gradients and also returns an HVP function internally, without taking two additional passes over the graph (i.e. one pass for the value and gradient, then another two, forward and backward, for the HVP).

Is this doable? Apologies if this is somewhat nebulous--I'm happy to clarify.
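In plain JAX, one way to get the value, the gradient, and a reusable HVP closure from a single linearisation is jax.linearize applied to jax.value_and_grad (forward-over-reverse). A minimal sketch; the toy objective f is hypothetical, aux outputs are omitted, and how this would plug into an Optimistix solver is left open. Our understanding is that each hvp call then costs roughly one JVP through the stored linear map, at the price of keeping its intermediate residuals in memory.

```python
import jax
import jax.numpy as jnp


def value_grad_and_hvp(f, x):
    """Return f(x), grad f(x) and a Hessian-vector-product closure at x.

    jax.linearize runs value_and_grad once and keeps its linearisation, so each
    hvp(v) call evaluates the stored linear map (forward-over-reverse) rather
    than recomputing the value and gradient.
    """
    (value, grad), jvp_fn = jax.linearize(jax.value_and_grad(f), x)

    def hvp(v):
        # The tangent of (value, grad) along v is (grad . v, H v); keep H v.
        _, hv = jvp_fn(v)
        return hv

    return value, grad, hvp


def f(x):
    return jnp.sum(x ** 3) + x[0] * x[1]


value, grad, hvp = value_grad_and_hvp(f, jnp.array([1.0, 2.0]))
hv = hvp(jnp.array([1.0, -1.0]))
```

Here H = [[6, 1], [1, 12]] at x = (1, 2), so hv is (5, -11). The closure can be called repeatedly (e.g. inside a CG loop) without re-running value_and_grad.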

Improve `IndirectIterativeDual`

  • Add IndirectIterativeDual-specific Newton safeguards (Conn, Gould, and Toint "Trust Region Methods" section 7.3)
  • Use Givens rotations to compute diff for different values of λ more efficiently (Conn, Gould, and Toint section 7.3 or Nocedal Wright section 4.3.) Depends on google/lineax#6
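A plain-JAX sketch of the Givens-rotation idea, in the spirit of MINPACK's qrsolv as we understand it (not Optimistix's or Lineax's code): factorise J = QR once, then for each trial λ stack [R; √λ I] and restore triangular form with Givens rotations. Minimising ‖[R; √λ I] p + [Qᵀr; 0]‖ gives exactly (JᵀJ + λI) p = -Jᵀr, and each new λ costs O(n³) work independent of the number of residuals m. The function name and the assumptions (λ > 0, J of full column rank) are ours.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def damped_step_from_qr(R, qtr, lam):
    """Solve (J^T J + lam I) p = -J^T r from a reduced QR J = Q R and qtr = Q^T r."""
    n = R.shape[0]
    for j in range(n):
        # Row j of the sqrt(lam) * I block, to be rotated into R.
        row = jnp.zeros(n, R.dtype).at[j].set(jnp.sqrt(lam))
        extra = jnp.zeros((), R.dtype)  # its right-hand-side entry (starts at 0)
        for k in range(j, n):
            a, b = R[k, k], row[k]
            rad = jnp.hypot(a, b)
            c, s = a / rad, b / rad
            Rk, qk = R[k], qtr[k]
            # Rotate rows (R[k], row) so that row[k] becomes zero; fill-in only
            # appears to the right of column k.
            R = R.at[k].set(c * Rk + s * row)
            row = -s * Rk + c * row
            qtr = qtr.at[k].set(c * qk + s * extra)
            extra = -s * qk + c * extra
    # R is now the triangular factor of [R; sqrt(lam) I].
    return -solve_triangular(R, qtr, lower=False)


key_j, key_r = jax.random.split(jax.random.PRNGKey(0))
J = jax.random.normal(key_j, (8, 3))
r = jax.random.normal(key_r, (8,))
Q, R = jnp.linalg.qr(J)  # factorise J once...
qtr = Q.T @ r
steps = {lam: damped_step_from_qr(R, qtr, lam) for lam in (0.1, 1.0, 10.0)}  # ...reuse per lam
```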

grad of vmap of function which wraps an optax solver occasionally fails

Hi,
I previously had the optx Newton root-finding algorithm in operation, using a jnp.where to set a default value when the root finder couldn't find a solution. It inserted the default value correctly, but the program failed to compute the gradient whenever the default value was used.
I ended up switching to the optx minimiser wrapper for an Optax solver, minimising a function in place of the root-finding operation, and this works very nicely as it handles the more extreme slopes that occur in my functions.
But then parameters elsewhere were changed such that two long numbers equalled the negative of each other at x64 precision. The point is not that I need to buy a lottery ticket, but that I need a way to make grad work when the solver cannot find a solution.

Specifics: I use vmap to fill in the elements of a 1D array, calling a function via vmap for each element of the array. That function includes the following code:

New code: returns value = y0, but the gradient computation crashes when no solution exists:

optimizer_acc = optx.OptaxMinimiser(optax.adabelief(learning_rate=1e-2), rtol=1e-8, atol=1e-8)
y0 = (jnp.array(0.01))
sol = optx.minimise(fn=time_root_from_distance, solver = optimizer_acc, y0 = y0, args = instance_of_acceleration, options=dict(lower=0.),max_steps=10000, throw=False)

Old code: returns the default value, but crashes when the gradient is requested:

solver_root = optx.Newton(rtol=1e-5, atol=1e-4)
y0 = jnp.array(0.01)
sol = optx.root_find(fn=time_root_from_distance, solver=solver_root, y0=y0, args=instance_of_acceleration, options=dict(lower=0.), max_steps=10000, throw=False)
Thv = jnp.where(sol.result == optx.RESULTS.successful, sol.value, 9999.)

Both work when they can find a solution. But when a solution does not exist or cannot be found, jax.grad(Objective) fails.

Not sure if this question is misplaced, but any suggestions on how to return a gradient, not just a value, when no solution exists would be appreciated.
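We can't tell from the report whether this is the failure here, but a common reason gradients break around jnp.where in JAX is that both branches are differentiated: if the unselected branch has a NaN or infinite derivative, 0 × NaN poisons the result. The usual workaround is the "double where": feed the risky computation a known-safe dummy input where the real input is invalid, then select the default afterwards. A minimal sketch; unsafe() is a hypothetical stand-in for a solve that only makes sense for x > 0. Applying this to root_find would mean deciding, before the solve, whether a root exists and substituting safe arguments when it doesn't, which is problem-specific.

```python
import jax
import jax.numpy as jnp


def unsafe(x):
    # Hypothetical stand-in for a solve that is only valid for x > 0.
    return jnp.sqrt(x)


def naive(x):
    # Single where: the gradient of the unused branch is still computed.
    return jnp.where(x > 0, unsafe(x), 9999.0)


def guarded(x):
    ok = x > 0
    # Feed the risky computation a harmless dummy input where x is invalid...
    safe_x = jnp.where(ok, x, 1.0)
    # ...then pick the default value afterwards.
    return jnp.where(ok, unsafe(safe_x), 9999.0)


g_naive = jax.grad(naive)(-1.0)      # NaN
g_guarded = jax.grad(guarded)(-1.0)  # 0.0
```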

Thanks,
Tom

Usage with 'vmap'

Hi, looks like a very promising library, this bit in the docs got me interested:

Unlike the SciPy implementation of Newton's method, the Optimistix version also works for vector-valued (or PyTree-valued) y.

Does this mean that the function passed to the root finder has to be vectorized in the traditional sense? Do the root finders here support functions which rely on vmap? I couldn't find anything in the docs about this. Thanks!
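As a general JAX pattern (not a statement about Optimistix internals), the function can be written for a single problem and the whole solve vmapped over a batch of arguments, with no manual vectorisation. A minimal sketch using a hand-rolled fixed-iteration Newton method (newton_scalar and fn are hypothetical, not optx.Newton); whether optx.root_find composes the same way is for its docs to confirm, and the later vmap report below suggests returning only sol.value from the vmapped function.

```python
import jax
import jax.numpy as jnp


def fn(y, b):
    # Scalar in, scalar out: written for a single problem.
    return y ** 2 - b


def newton_scalar(f, y0, args, num_steps=30):
    df = jax.grad(f)  # derivative with respect to y

    def body(_, y):
        return y - f(y, args) / df(y, args)

    return jax.lax.fori_loop(0, num_steps, body, y0)


# vmap the entire solve over a batch of initial guesses and arguments.
batched_solve = jax.vmap(lambda y0, b: newton_scalar(fn, y0, b))
roots = batched_solve(jnp.ones(3), jnp.array([1.0, 4.0, 9.0]))
```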

New solvers

  • Anderson acceleration
  • LBFGS
  • Affine

Powell's (unconstrained) derivative free optimisers:

  • UOBYQA
  • NEWUOA

On affine solvers: such systems can be handled with a single linear solve. JAX can detect affine functions via:

import jax
import jax.interpreters.partial_eval as pe

def is_affine(f, *args, **kwargs):
    jaxpr = jax.make_jaxpr(jax.jacfwd(f))(*args, **kwargs)
    _, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals))
    return all(not x for x in used_inputs)
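Once a function is known to be affine, f(y) = A y + c, the "single linear solve" could look like this: read off c = f(0) and the constant Jacobian A, then solve A y = -c. A minimal dense sketch (affine_root is a hypothetical helper, assuming a square, invertible A and small enough problems to materialise the Jacobian).

```python
import jax
import jax.numpy as jnp


def affine_root(f, y_like):
    """Root of an affine f(y) = A y + c via one linear solve."""
    zero = jnp.zeros_like(y_like)
    c = f(zero)              # f(0) = c
    A = jax.jacfwd(f)(zero)  # the Jacobian of an affine map is constant
    return jnp.linalg.solve(A, -c)


A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
c = jnp.array([1.0, -1.0])
f = lambda y: A @ y + c
root = affine_root(f, jnp.zeros(2))
```

For this A and c the root is (-0.6, 0.8), found with no iteration at all.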

Issue with vmap `optx.least_squares`.

Hi,

I have issues vectorizing the optx.least_squares function (version 0.0.6) when directly vectorized using JAX's vmap function. This behavior occurs unless the sol.state and sol.result fields are removed from the Solution dataclass instance. Perhaps related to this commit? Somehow, vmap does not know that the jaxpr stuff should not be batched (i.e. pytree_node=False).

MWE

In the provided Minimum Working Example (MWE), I attempt to vectorize the least squares optimization using JAX's vmap function. The process involves a quadratic residual function and the Levenberg-Marquardt solver.

The vectorization attempt fails when trying to return the full Solution object (sol) from the least_squares function. However, if only sol.value is returned, the vectorization succeeds. This suggests a compatibility issue between the full Solution instance structure and JAX's vmap operation.

import jax
import jax.numpy as jnp
import optimistix as optx

# Define a simple quadratic residual function
def residual_fn(params, *args):
    return params[0] * jnp.arange(10) ** 2 + params[1] * jnp.arange(10) + params[2]

# Initialize the Levenberg-Marquardt solver
solver = optx.LevenbergMarquardt(rtol=1e-5, atol=1e-7, norm=optx.rms_norm)

# Define the initial parameters
params_init = jnp.array([1.0, 2.0, 3.0])

# Define a function to perform least squares optimization
def least_squares(params_init):
    sol = optx.least_squares(residual_fn, solver, params_init, max_steps=100, throw=False)
    return sol  # throws an error --> returning sol.value does not... 

# Attempt to vectorize the least squares function
vmap_least_squares = jax.vmap(least_squares, in_axes=(0,), out_axes=0)

# Define a batch of initial parameters
batch_params_init = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# Attempt to perform batched least squares optimization
batch_params_final = vmap_least_squares(batch_params_init)

This produces the error:

Traceback (most recent call last):
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-521a622dc7fd>", line 27, in <module>
    batch_params_final = vmap_least_squares(batch_params_init)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/api.py", line 1258, in vmap_f
    out_flat = batching.batch(
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/linear_util.py", line 206, in call_wrapped
    ans = gen.send(ans)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 638, in _batch_inner
    out_vals = map(partial(from_elt, trace, axis_size), outs, out_dim_dests)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 270, in from_elt
    return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
  File "/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.9/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 1107, in matchaxis
    raise TypeError(f"Output from batched function {x!r} with type "
TypeError: Output from batched function { lambda a:f32[10] b:f32[10]; c:f32[3]. let
    d:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] c
    e:f32[] = squeeze[dimensions=(0,)] d
    f:f32[10] = mul e a
    g:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] c
    h:f32[] = squeeze[dimensions=(0,)] g
    i:f32[10] = mul h b
    j:f32[10] = add f i
    k:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] c
    l:f32[] = squeeze[dimensions=(0,)] k
    m:f32[10] = add j l
  in (m,) } with type <class 'jax._src.core.Jaxpr'> is not a valid JAX type

Error in "optimistix/docs/examples /optimise_diffeq.ipynb"

Hi, I was trying to run the example given in "optimistix/docs/examples /optimise_diffeq.ipynb". For some reason I am receiving error "---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
__________________________________________________________________ Cell 4 line 9
5 pred_values = batch_solve(parameters, y0s, saveat)
6 return values - pred_values
----> 9 (y0s_0, values_0) = get_data()
10 y0s = jnp.array(y0s_0)
11 values = jnp.array(values_0)

__________________________________________________________________ Cell 4 line 1
9 saveat = dfx.SaveAt(ts=jnp.linspace(0, 30, 20))
10 batch_solve = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(None, 0, None)))
---> 11 values = batch_solve(true_parameters, y0s, saveat)
12 return y0s, values

[... skipping hidden 21 frame]

__________________________________________________________________ Cell 4 line 2
19 t1 = saveat.subs.ts[-1]
20 dt0 = 0.1
---> 21 sol = dfx.diffeqsolve(
22 term,
23 solver,
24 t0,
25 t1,
...
--> 305 raise ValueError("No arrays to thread error on to.")
306 dynamic_x = _error(dynamic_x, pred, index, msgs=msgs, on_error=on_error)
307 return combine(dynamic_x, static_x)

ValueError: No arrays to thread error on to.". I would appreciate any help fixing it.

Can't use Optimistix solvers with `eqx.Module`s and filtered transformations

Thanks very much for this library! Though I understand it's not the primary use case, I'd like to use optimistix with first-order gradient optimizers and standard neural nets to make use of the ability to vectorize optimizers. (Specifically, I'd like to train an ensemble like in equinox, but where each member of the ensemble is paired with a distinct optimizer.)

I run into an error when using optx.GradientDescent with an eqx.Module. Adapting some example code from this repo for a MWE:

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx

N = K = 8
x = jnp.linspace(0, 1, N)[None, ...]
y = x**2

model = eqx.nn.MLP(
  in_size=N,
  out_size=N,
  width_size=K,
  depth=1,
  activation=jax.nn.relu,
  key=jax.random.PRNGKey(42),
)


@eqx.filter_jit
def loss(model, args):
  x, y = args
  pred_y = eqx.filter_vmap(model)(x)
  loss = jnp.mean((pred_y - y) ** 2)
  aux = None
  return loss, aux


optimizer = optx.GradientDescent(learning_rate=1e-1, rtol=1e-4, atol=1e-4)
options = None
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

init = eqx.filter_jit(eqx.Partial(optimizer.init, fn=loss, options=options, f_struct=f_struct,
                      aux_struct=aux_struct, tags=tags))
step = eqx.filter_jit(eqx.Partial(optimizer.step, fn=loss, options=options, tags=tags))
terminate = eqx.filter_jit(eqx.Partial(optimizer.terminate, fn=loss, options=options, tags=tags))
postprocess = eqx.filter_jit(eqx.Partial(optimizer.postprocess, fn=loss, options=options, tags=tags))

state = init(y=model, args=(x, y))
done, result = terminate(y=model, args=(x, y), state=state)

while not done:
  model, state, _ = step(y=model, args=(x, y), state=state)
  done, result = terminate(y=model, args=(x, y), state=state)
  print(f"Evaluating iteration with loss value {loss(model, (x, y))[0]}.")

if result != optx.RESULTS.successful:
  print("Failed!")

model, _, _ = postprocess(
  y=model,
  aux=None,
  args=(x, y),
  state=state,
  result=result,
)
print(f"Found solution with loss value {loss(model, (x, y))[0]}.")

gives me:

TypeError: Value <jax._src.custom_derivatives.custom_jvp object at 0x1022171d0> with type <class 'jax._src.custom_derivatives.custom_jvp'> is not a valid JAX type

at this line:

f_info_struct = jax.eval_shape(lambda: f_info)

which, if I understand correctly, is the result of jax.eval_shape hitting non-arrays. How can I filter for arrays in model, or is there a different recommended usage pattern here?

BestSoFarMinimiser behavior

Not sure if this is a bug or not, but BestSoFarMinimiser appears to not check the last step of the wrapped solver:

solver = optimistix.BestSoFarMinimiser(optimistix.BFGS(rtol=1e-5, atol=1e-5))
ret = optimistix.minimise(lambda x, _: (x - 3.)**2, solver, 0.)
print(ret.value, ret.state.state.y_eval)

0.0 3.0
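We haven't checked BestSoFarMinimiser's internals, but the output above is consistent with a wrapper that scores each iterate before the wrapped solver steps, so the final iterate is never compared. A plain-JAX sketch of a best-so-far loop that also checks the last point; gradient descent stands in for the wrapped solver (an assumption), and best_so_far_gd is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def best_so_far_gd(f, y0, learning_rate, num_steps):
    """Gradient descent that returns the best iterate seen, including the last one."""
    grad_f = jax.grad(f)

    def body(_, carry):
        y, best_y, best_f = carry
        fy = f(y)
        better = fy < best_f
        best_y = jnp.where(better, y, best_y)
        best_f = jnp.where(better, fy, best_f)
        return y - learning_rate * grad_f(y), best_y, best_f

    y0 = jnp.asarray(y0, dtype=jnp.float32)
    y, best_y, best_f = jax.lax.fori_loop(0, num_steps, body, (y0, y0, f(y0)))
    # The loop body scores y before stepping, so the final iterate has not been
    # compared yet; check it explicitly.
    return jnp.where(f(y) < best_f, y, best_y)


f = lambda x: (x - 3.0) ** 2
result = best_so_far_gd(f, 0.0, learning_rate=0.5, num_steps=1)
```

With one step the iterate jumps from 0.0 straight to 3.0; without the final check the function would return 0.0, mirroring the behaviour reported above.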

TypeError

Excited to explore the library as always!

from typing import NamedTuple

import jax
import jax.numpy as jnp
import optimistix

class MiniData(NamedTuple):
    X: jax.Array
    Y: jax.Array

def loss_fn_per_obs(y, p):
    return jnp.where(y==1.0, -jnp.log(p ), -jnp.log(1-p ))

def fn(params, args):
    P =  jax.nn.sigmoid(args.X @ params)
    losses = jax.vmap(loss_fn_per_obs, in_axes=(0,0))(args.Y, P)
    return jnp.mean(losses)

init_params = jax.random.normal(jax.random.PRNGKey(0), shape=(19,1))
data = MiniData(X=jax.random.normal(jax.random.PRNGKey(1), shape=(100, 19)),
                Y= jax.random.normal(jax.random.PRNGKey(2), shape=(100, 1)))
solver = optimistix.NonlinearCG(rtol=0.01, atol=0.01)
optimistix.minimise(fn=fn, solver=solver, y0 = init_params, args=data, has_aux=False)

I am running into the following type error:

TypeError: linearize() got an unexpected keyword argument 'has_aux'

Pytree inputs for `rtol` and `atol` or custom termination condition?

So I would love to be able to pass in a pytree for both the rtol and atol values, in a similar vein to how you can set individual learning rates for each leaf in optax. This would make a lot of sense for most of my work which has pytree leaves with vastly different parameter scales.

Looking at the termination condition code, it looks like this hasn't been made an option because the values are applied to both the pytree leaf values ('y space') and the loss value ('f space').

From what I can tell, there would be two ways to get this behavior:

  1. Allow custom termination condition.

    I don't think this is the best solution, as the Cauchy termination is already wrapped up with the input norm function/pytree.

  2. Allow pytree inputs for rtol and atol.

    I think this could be done relatively easily by allowing the f-space and y-space conditions to be individually specified via a tuple like this: (f-space rtol (float), y-space rtol (float or pytree)). This would preserve the present syntax, while also allowing users full freedom over the termination condition.
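To make option 2 concrete, a y-space Cauchy-style check with per-leaf tolerances might look like the plain-JAX sketch below. It is not Optimistix's termination code; per_leaf_converged and the leaf names are hypothetical, and the f-space condition would remain a scalar test.

```python
import jax
import jax.numpy as jnp


def per_leaf_converged(y_prev, y_new, rtol, atol):
    """Check |y_new - y_prev| <= atol + rtol * |y_new| leaf by leaf.

    rtol and atol are pytrees with the same structure as y (one float per leaf).
    """
    def leaf_ok(yp, yn, r, a):
        return jnp.all(jnp.abs(yn - yp) <= a + r * jnp.abs(yn))

    oks = jax.tree_util.tree_map(leaf_ok, y_prev, y_new, rtol, atol)
    return jnp.all(jnp.stack(jax.tree_util.tree_leaves(oks)))


y_prev = {"tiny": jnp.array(1.0e-6), "big": jnp.array(1000.0)}
y_new = {"tiny": jnp.array(1.1e-6), "big": jnp.array(1000.5)}
atol = {"tiny": 0.0, "big": 0.0}
loose = per_leaf_converged(y_prev, y_new, {"tiny": 0.2, "big": 1e-3}, atol)
tight = per_leaf_converged(y_prev, y_new, {"tiny": 0.2, "big": 1e-4}, atol)
```

With rtol 1e-3 the "big" leaf's change of 0.5 is within tolerance (about 1.0); tightening that leaf alone to 1e-4 makes the check fail, without affecting the "tiny" leaf.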

Anyway maybe there is a better way that I'm missing, but having this functionality is actually somewhat essential for using optimistix in my work in the long run, so let me know your thoughts!

Including user-defined Jacobian

Hi devs, looks like a really nice library. I've been looking for a JAX-native root finding method that supports vmap for some time. Currently I am using an external call to scipy.optimize.root together with the multiprocessing library, which is quite slow.

The runtime for root finding using the Newton method in this library is slower than the above method though - I suspect this is because the Jacobian needs to be calculated at each iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?

For reference, this is my MWE in case I am not doing things efficiently:

from jax import jit, jacfwd, vmap, random
import optimistix as optx

def fn(y, b):
    return (y-b)**2

M = 1024
key = random.PRNGKey(42)
key, key_ = random.split(key, 2)

y = random.normal(key, (M,))
b = random.normal(key_, (M,))
solver = optx.Newton(rtol=1e-8, atol=1e-8)  # tolerances are illustrative
sol = optx.root_find(vmap(fn), solver, y, b)
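As far as we know, Optimistix's Newton obtains derivatives of fn through JAX autodiff, so (an assumption, not verified against its source) a jax.custom_jvp rule is one way to supply an analytic derivative without touching the solver. Separately, vmap(fn) presents the solver with one M-dimensional system whose Jacobian is diagonal; if that Jacobian is materialised densely, each step handles an M×M matrix, which may account for some of the slowdown. The sketch below only shows the custom derivative rule and checks it against the expected diagonal Jacobian.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def fn(y, b):
    return (y - b) ** 2


@fn.defjvp
def fn_jvp(primals, tangents):
    y, b = primals
    y_dot, b_dot = tangents
    # Hand-written partial derivatives: d/dy = 2(y - b), d/db = -2(y - b).
    dfdy = 2.0 * (y - b)
    return fn(y, b), dfdy * y_dot - dfdy * b_dot


y = jnp.array([0.0, 1.0, 3.0])
b = jnp.array([1.0, 1.0, 1.0])
# Per-element derivative of vmap(fn), i.e. the diagonal of its Jacobian.
diag = jax.vmap(jax.grad(fn))(y, b)
# Full Jacobian of the batched function, built from the custom rule.
jac = jax.jacfwd(jax.vmap(fn))(y, b)
```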

Using OptaxMinimiser results in AttributeError

Using the OptaxMinimiser as solver results in " '_Closure' object has no attribute 'init' " whereas the BFGS solver runs without errors.

The objective function uses a custom equinox pytree.
