Comments (7)
Hmm, so my first question is what kind of behaviour you're seeking, specifically?
Basically, what goes wrong with fixing an `rtol`, which already handles multiple scales?
from optimistix.
Haha yeah, sorry, I probably should have just provided an example of the desired functionality:
```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx
import optax

# Set up model
class Linear(eqx.Module):
    m: jax.Array
    b: jax.Array

    def __init__(self, m, b):
        self.m = m
        self.b = b

    def __call__(self, x):
        return self.m * x + self.b

# Simple loss
def loss_fn(model, args):
    x, y = args
    return jnp.mean((model(x) - y) ** 2)

# Normal optax optimiser
linear = Linear(jnp.array(1.0), jnp.array(0.0))
param_spec = eqx.tree_at(lambda x: (x.m, x.b), linear, ("m", "b"))
optim = optax.multi_transform({"m": optax.adam(1e-3), "b": optax.adam(1e3)}, param_spec)

# Per-leaf atol and rtol
rtol = eqx.tree_at(lambda x: (x.m, x.b), linear, (0.1, 10))
atol = eqx.tree_at(lambda x: (x.m, x.b), linear, (0.1, 10))

# Optimistix minimiser (proposed API -- tuple tolerances are not currently supported)
solver = optx.OptaxMinimiser(
    optim,
    rtol=(1e-3, rtol),  # f-space rtol, y-space rtol
    atol=(1e-3, atol),  # f-space atol, y-space atol
)
```
So in this example my termination condition would have a different `rtol` and `atol` for the loss (f-space) and for each leaf (y-space), as opposed to having the same termination value applied to everything. Does that clarify my question?
It's also possible I have misunderstood something about how the termination condition works, so please let me know if that's the case!
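Concretely, the per-leaf check being asked for might look something like the sketch below. This is not optimistix's actual implementation -- it is a rough Cauchy-style convergence test, with plain Python dicts standing in for pytrees, and all names are mine:

```python
# Sketch of a per-leaf Cauchy-style convergence check (hypothetical, not
# the optimistix implementation). Plain dicts stand in for pytrees.
def leaf_converged(old, new, atol, rtol):
    # A leaf has converged when successive iterates differ by less than
    # its own combined tolerance, atol + rtol * |new|.
    return abs(new - old) < atol + rtol * abs(new)

def tree_converged(old, new, atol, rtol):
    # Every leaf must satisfy its own tolerance for the solve to stop.
    return all(
        leaf_converged(old[k], new[k], atol[k], rtol[k])
        for k in old
    )

# Position judged absolutely (atol), brightness judged relatively (rtol):
old = {"position": 1.00e-3, "brightness": 1.0e8}
new = {"position": 1.05e-3, "brightness": 1.0e8 + 1e4}
atol = {"position": 1e-3, "brightness": 0.0}
rtol = {"position": 0.0, "brightness": 1e-3}
print(tree_converged(old, new, atol, rtol))  # True: each leaf within its own tolerance
```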
Sure! Sorry, to be clear, I understand the ask, and the fact that Optimistix doesn't support this right now. What I'm trying to better understand (before thinking about a possible solution) is why this kind of mixed-tolerance is a desirable thing to want in the first place.
Typically the reason for having `rtol` is so that you can be sure of getting solutions whose accuracy scales linearly with the scale of the problem. (And likewise `atol` exists to get scale-invariant accuracies.)
So if that isn't sufficient, is it because you want some nonlinear function scale->accuracy? What kind of nonlinear function / why?
Ah okay, yeah, let me explain my reasoning, as I'm not well versed in the theory behind all this stuff, so I might just not understand the use of the `rtol` and `atol` values correctly.
We are optimising forward models with great diversity in parameter scales, on problems with large diversity in likelihoods. Taking a two-parameter example, we might be trying to find both the position and brightness of a star imaged through a telescope. On-sky position is typically measured in arcseconds of order 1e-3, and brightness is typically measured in photons, which can range in value from 1e4 to 1e12. Having a single termination value for both of these parameters is difficult.
The on-sky position measurement is relative to the optics, so if our true values can be arbitrarily close to zero, the desired `rtol` would need to be somewhat large (or possibly even ignored), and the `atol` would want to be ~1e-3.
The brightness of the star, however, could cover many orders of magnitude, so there isn't really a concrete `atol` value that makes sense (i.e. if we take 1e-3 to match the position, that would be far too small). In this case we would need an `rtol` of ~1e-3.
That is essentially the core of my issue, and maybe this should be framed as more of a question: how would you go about devising convergence criteria for a problem like this? Am I thinking about the convergence incorrectly in the first place?
From what you've said, I suspect taking `atol=1e-3` and `rtol=1e-3`, or thereabouts, should be about right. The overall scale is given by `atol + rtol * value`, so the `rtol` will be negligible for the on-sky position and the `atol` will be negligible for the brightness.
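Putting numbers on that -- a back-of-the-envelope sketch using the scales from the star example above, not optimistix output:

```python
# The combined tolerance scale is atol + rtol * |value|.
atol, rtol = 1e-3, 1e-3

position = 5e-4    # arcseconds: near zero, so the atol term dominates
brightness = 1e8   # photons: very large, so the rtol term dominates

scale_position = atol + rtol * abs(position)
scale_brightness = atol + rtol * abs(brightness)

print(scale_position)    # ~1e-3 (the rtol term contributes only 5e-7)
print(scale_brightness)  # ~1e5  (the atol term is negligible)
```

So a single `(atol, rtol)` pair already yields an absolute ~1e-3 tolerance for the position and a relative ~1e-3 tolerance for the brightness, with no per-leaf machinery.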
For what it's worth, if we were to change this, then I'd be tempted to do it through the `norm` instead -- perhaps introduce a separate norm for the `y` and `f` spaces. Then you can scale each component however might be desired.
Yeah, that was just a small example -- in practice we are optimizing over a dozen unique sets of parameters, so trying to find a balance between every leaf type becomes unwieldy.
I think I'm starting to get my head around the way this works under the hood. Building a robust `norm` function seems like it could solve this problem, and possibly also be used as a way to normalize parameter values through a custom solver.
Ultimately all of these questions are also in the context of #20, where parameter scales are just a problem in general, so a robust solution would also need to be cognizant of that too. I'll have to look into this more when I have some time to consider both the problem and the solutions more carefully.
Thanks for the info!
You're welcome! Let us know how it goes -- we can definitely add something to the API if your problem ends up being tricky to implement as-is.