Hey, first up: another awesome package in the JAX ecosystem! I've been meaning to incorporate these kinds of solvers into my work for a long time, so thanks for making it easy 😛. This is partly a discussion post, as I'm relatively unfamiliar with these algorithms. I've done my best to parse the docs in detail, but feel free to point me to any external resources, as I would love to learn more.
Mapping solvers to different leaves
Is there a way to map a different solver to each leaf of a pytree? Let's say we know one parameter will be initialised in a smooth bowl of the loss space and can be solved with `BFGS`, but another parameter has a 'noisy' loss topology and is best tackled with a regular `optax` optimiser. This is actually quite typical for the sort of optical models I work with, although not super common in general AFAIK.
It is simple to apply each of these algorithms one at a time to each pytree leaf with `eqx.partition` and `eqx.combine`. This approach works, but it can't 'jointly' optimise these leaves, and it results in redundant calculation of the model gradients, since a single gradient evaluation could serve both algorithms.
Now I recognise that a 'joint' approach would pose a problem for algorithms like `BFGS`, since they would be trying to find the minimum of a dynamic topology that changes as the other leaves are updated throughout the optimisation. I would be curious what you think the right approach to this kind of problem might be; maybe there are solvers designed for it? If not, what approach would you take? I'm very excited about the flexibility and extensibility of this software for building out much better custom solvers for my rather niche set of optimisation problems.
Parameter normalisation during the solve loop
So during a gradient descent loop we commonly need to apply some normalisation/regularisation to the updated parameters to ensure they are 'physical'. An example would be normalising relative spectral weights to sum to 1 after the updates have been applied. I am wondering if there is a way to enforce these constraints during the solve. The simplest example case here would be preventing some values from exceeding a threshold.
I would guess this is likely possible through a user-defined solver class that applies the custom regularisation. If something like this is possible, how would it be implemented? From a crude look at the code, it seems this could be done within the `step` function of the `AbstractGradientDescent` class?
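To be concrete, the behaviour I'm after looks like a projection applied after each update. Here's what I mean in a hand-rolled bare-JAX loop (the `weights` leaf and target values are made up, and the normalisation is a crude projection rather than an exact Euclidean one):

```python
import jax
import jax.numpy as jnp

def loss(p):
    # hypothetical objective over relative spectral weights
    return jnp.sum((p["weights"] - jnp.array([0.1, 0.9])) ** 2)

def project(p):
    # enforce the 'physical' constraint: non-negative weights summing to 1
    w = jnp.clip(p["weights"], 0.0)
    return {"weights": w / jnp.sum(w)}

params = {"weights": jnp.array([0.5, 0.5])}
for _ in range(200):
    grads = jax.grad(loss)(params)
    params = jax.tree_util.tree_map(lambda x, g: x - 0.1 * g, params, grads)
    params = project(params)  # re-normalise after every update
```

The question is essentially where a `project`-style hook would live inside a solver's step.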
Parameters with large scale variation
So this one is more of an open discussion than a specific question. It's very common for the models I work with to have parameters of vastly different scales (everything from 1e10 to 1e-10). This is a problem for these algorithms in general, so I was hoping to get your thoughts on the right way to approach a solution.
There is the 'naive' solution: apply a scaling to each parameter of the pytree before passing it into the minimisation function, then invert the scaling once inside the function. This works, but it is far from what I would consider ideal, as it still requires a degree of prior knowledge of the model and sort of just kicks the tunable hyper-parameter from a learning rate into a scaling. Granted, this is still generally going to be more robust, but I feel like there is something more elegant... I'm wondering if you have any thoughts or ideas about this!
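For reference, the naive version I mean looks like this (leaf names, scales, and the toy objective are all made up; nothing here is Optimistix machinery):

```python
import jax
import jax.numpy as jnp

# hypothetical per-parameter scales, chosen from prior knowledge of the model
scales = {"flux": 1e10, "wavelength": 1e-6}

def to_unit(params):
    # rescale physical parameters so the solver only ever sees O(1) values
    return jax.tree_util.tree_map(lambda x, s: x / s, params, scales)

def from_unit(unit_params):
    # invert the scaling inside the objective
    return jax.tree_util.tree_map(lambda x, s: x * s, unit_params, scales)

def physical_loss(params):
    # toy objective with a minimum at flux=2e10, wavelength=3e-6
    return (params["flux"] / 1e10 - 2.0) ** 2 + (params["wavelength"] / 1e-6 - 3.0) ** 2

def scaled_loss(unit_params):
    return physical_loss(from_unit(unit_params))

# plain gradient descent now behaves sensibly with a single learning rate
unit = to_unit({"flux": jnp.array(1e10), "wavelength": jnp.array(1e-6)})
for _ in range(100):
    grads = jax.grad(scaled_loss)(unit)
    unit = jax.tree_util.tree_map(lambda x, g: x - 0.1 * g, unit, grads)
params = from_unit(unit)
```

It works, but the `scales` dict is exactly the hand-tuned prior knowledge I'd love to avoid.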
Anyway thanks again for the excellent software and the help!