ucl-bug / jaxdf
A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
License: GNU Lesser General Public License v3.0
It would be nice to have a new FiniteDifferences-like class based on mpi4jax, to enable HPC applications.
A good milestone could be to reproduce the shallow water example of the library (original code here)
I would like to be able to control the finite difference scheme used, i.e. forward, backward or central. Depending on the PDE, we normally pick the scheme per term, e.g. backward differences for advection and central differences for diffusion.
Working Demo
I have a working colab notebook to get a feeling for what I mean. See it here.
Proposed Solution
I don't have a solution, but I think it is important to specify this somewhere in the param PyTree (just like the accuracy, order, step size, etc.).
# Proposed API (sketch only, not implemented)
from typing import Iterable

u = DiscretizationScheme(u_init, domain)

class Params:
    method: str = static_field()
    stagger: Iterable[int] = static_field()
    accuracy: int = static_field()

params = Params(method="central", stagger=[0], accuracy=2)
u_grad = gradient(u=u, params=params)
Another possible solution: one could use the FiniteDiffX package as a backend for generating the coefficients and kernel if they are not specified. I recently contributed the ability to specify the FD scheme there.
Last solution: Just create a custom operator that does exactly all that I've said before. There is an example in the "custom equation of motion" section which does exactly what I want.
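To make the idea of generating coefficients for a chosen scheme concrete, here is a minimal NumPy sketch. It is not jaxdf or FiniteDiffX code: the helper fd_coefficients and the STENCILS mapping are hypothetical names, and the weights come from the standard Taylor-expansion (Vandermonde) system for a given set of stencil offsets.

```python
import numpy as np
from math import factorial


def fd_coefficients(offsets, derivative=1):
    """Weights w such that sum_k w[k] * f(x + offsets[k] * h) / h**derivative
    approximates the requested derivative of f at x."""
    offsets = np.asarray(offsets, dtype=float)
    n = len(offsets)
    # Row i of A enforces: sum_k w[k] * offsets[k]**i == i! if i == derivative else 0
    A = np.vander(offsets, n, increasing=True).T
    b = np.zeros(n)
    b[derivative] = factorial(derivative)
    return np.linalg.solve(A, b)


# Hypothetical mapping from a scheme name to first-derivative stencil offsets
STENCILS = {
    "forward": [0, 1],
    "backward": [-1, 0],
    "central": [-1, 0, 1],
}

for name, offsets in STENCILS.items():
    print(name, fd_coefficients(offsets))
```

A method field in the params could then simply select the offsets, while accuracy would control how many offsets are included.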
Following the suggestion in #110, it probably makes sense to generally allow for a custom backend or, if that's too complicated, leverage equinox and make it the default backend.
Describe the bug
When I run the example given in the README, I get the following error:
SignatureError: The argument 'params' must be a keyword argument in . Example: def evaluate(x, *, params): ...
To Reproduce
Steps to reproduce the behavior:
Expected behavior
No error. Computing the gradient at the end.
Would be nice to have some examples of using jaxdf for a few simulation applications, for example:
Describe the bug
Running the paper example on the latest version of jax (0.4.9) generates some high-frequency oscillations in the field that were not present before.
To Reproduce
Steps to reproduce the behavior:
Run the example_1_paper notebook
Expected behavior
The new oscillations simply should not be there.
Really there's no reason to use vjp instead of jvp. I'm tempted to call this a bug.
The .dx attribute only makes sense for fields defined on a grid, and can be uniquely determined from the shape of OnGrid.params and the size of the domain.
It is much more natural to define the domain as
domain = Domain(size: tuple)
# Or, potentially
domain = Domain.from_grid(N, dx)
It would be even better to define domains that are not rectangular, for example
domain = Domain()  # Abstract domain

class RectangularDomain(Domain):
    L: tuple

class SphericalDomain(Domain):
    R: float
which then allows defining non-standard discretizations. For example, Continuous can work on arbitrarily shaped domains, FourierSeries probably only makes sense on RectangularDomain, while something like a SphericalFourierSeries (see for example s2fft) could be implemented on a SphericalDomain.
This is clearly a breaking change.
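A rough sketch of how such a hierarchy could look, using plain dataclasses. This is only an illustration of the proposal, not the jaxdf API: the grid_spacing helper is a hypothetical name, and the convention that the side length equals N * dx is an assumption.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Domain:
    """Abstract domain: carries no grid information."""


@dataclass(frozen=True)
class RectangularDomain(Domain):
    L: Tuple[float, ...]  # physical side length along each axis

    @classmethod
    def from_grid(cls, N, dx):
        # Assumed convention: side length = number of points * spacing
        return cls(tuple(n * d for n, d in zip(N, dx)))


@dataclass(frozen=True)
class SphericalDomain(Domain):
    R: float  # radius


def grid_spacing(domain: RectangularDomain, shape):
    """Recover dx from the domain size and the shape of a gridded field."""
    return tuple(length / n for length, n in zip(domain.L, shape))


domain = RectangularDomain.from_grid(N=(64, 128), dx=(0.5, 0.25))
print(domain.L, grid_spacing(domain, (64, 128)))
```

With this layout the spacing is no longer stored on the domain; a gridded discretization derives it from its own array shape, while a grid-free discretization never needs it.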
Is there a native way to do a finite difference operator on a multidimensional field, i.e.
Using the current API, I don't see a way to specify the scheme. It could be done through some combination of the stagger option, but I could not find a way to do it. There is a gradient operation, but many times we just need a simple difference operator where we can choose the axis.
Note: This might be a problem due to the fact that we use convolutional operators for the FD scheme whereas normally I think of slicing.
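To illustrate the note above, here is a small NumPy sketch (not jaxdf code; both function names are made up for this example) showing that a central difference written with slicing and one written as a convolution give the same interior values.

```python
import numpy as np


def central_diff_slicing(u, dx):
    # Interior points only: (u[i+1] - u[i-1]) / (2 * dx)
    return (u[2:] - u[:-2]) / (2.0 * dx)


def central_diff_convolution(u, dx):
    # np.convolve flips the kernel, so [1, 0, -1] yields u[i+1] - u[i-1]
    kernel = np.array([1.0, 0.0, -1.0]) / (2.0 * dx)
    return np.convolve(u, kernel, mode="valid")


x = np.linspace(0.0, 1.0, 11)
u = x**2
dx = x[1] - x[0]

a = central_diff_slicing(u, dx)
b = central_diff_convolution(u, dx)
print(np.allclose(a, b))
```

The two views differ mainly in how one thinks about boundaries and axes: slicing makes the axis and the offsets explicit, while a convolution kernel encodes them in its shape.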
Demo
I have a demo colab notebook to showcase what I mean. The equation of motion is a 1D problem but if it were 2D then this API would not work.
Proposed Solution
No specific solution, but it might be helpful to have a simple API for this: many models, such as the shallow water and quasi-geostrophic models, contain a lot of advection terms and need it.
# current API
u_grad = gradient(u)
du_dx = u_grad.replace_params(u_grad.on_grid()[0])
# preferred API
du_dx = difference(u, axis=0)
Another solution is just to write a custom operator for the difference scheme. The colab notebook that I linked before has an example of this. This is also related to issues #127 and #125.
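As a sketch of what the preferred API might do internally, here is a NumPy version of an axis-wise difference. It is an assumption-laden illustration rather than a jaxdf operator: the signature difference(u, axis, dx, method) is hypothetical, and boundaries are treated as periodic purely to keep the code short.

```python
import numpy as np


def difference(u, axis=0, dx=1.0, method="central"):
    """Finite difference of u along one axis; output has the same shape as u.

    Boundaries wrap around (periodic), which is an assumption of this sketch.
    """
    fwd = np.roll(u, -1, axis=axis)  # u[i+1] along the chosen axis
    bwd = np.roll(u, 1, axis=axis)   # u[i-1] along the chosen axis
    if method == "forward":
        return (fwd - u) / dx
    if method == "backward":
        return (u - bwd) / dx
    if method == "central":
        return (fwd - bwd) / (2.0 * dx)
    raise ValueError(f"unknown method: {method!r}")


# 2D field that grows linearly along axis 0 and is constant along axis 1
u = np.outer(np.arange(8.0), np.ones(4))
du_dx = difference(u, axis=0, method="backward")
du_dy = difference(u, axis=1, method="central")
```

Away from the wrap-around rows, du_dx equals one everywhere and du_dy is zero, which is the behavior one would want from an advection-style term along a single axis.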
Is your feature request related to a problem? Please describe.
The current error that is returned to the user when params is not declared as a keyword argument is potentially not clear:
SignatureError: The argument 'params' must be a keyword argument in . Example: def evaluate(x, *, params)
Describe the solution you'd like
Make it clear to the user that params must be defined for operators, and give an example of how that must be done. At the moment, it relies on plum's signature error handling.
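One possible shape for a clearer message, sketched with the standard library only. This does not reflect how jaxdf or plum actually perform the check; the decorator name check_params_keyword_only is hypothetical.

```python
import inspect


def check_params_keyword_only(func):
    """Raise a descriptive TypeError unless func declares params as keyword-only."""
    sig = inspect.signature(func)
    param = sig.parameters.get("params")
    if param is None or param.kind is not inspect.Parameter.KEYWORD_ONLY:
        raise TypeError(
            f"Operator '{func.__name__}' must declare 'params' as a keyword-only "
            f"argument, e.g. def {func.__name__}(x, *, params=None): ..."
        )
    return func


@check_params_keyword_only
def good_operator(x, *, params=None):
    return x


try:
    @check_params_keyword_only
    def bad_operator(x, params=None):  # params is positional here, so this fails
        return x
except TypeError as err:
    message = str(err)
    print(message)
```

Naming the offending operator and showing a corrected signature in the message addresses the main complaint about the current error, which omits the function name.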
One immediate feature that emerged from the chat with @jejjohnson is the ability to work with fields in a way that allows hiding them from the user, or at least not explicitly working with them.
A common pattern for achieving this is given by the following code:
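from jax.typing import ArrayLike
from jaxdf import FourierSeries
from jaxdf.geometry import Domain

def my_awesome_func(u: ArrayLike):
    # Declare fields
    N = u.shape
    dx = [0.1,] * len(N)
    u_field = FourierSeries(u, Domain(N, dx))
    # Perform the desired operation using jaxdf
    # (some_operator stands for any jaxdf operator)
    v_field = some_operator(u_field)
    # Return a simple jax array
    return v_field.on_grid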
To simplify the syntax and achieve a cleaner implementation, this pattern can be encapsulated in a decorator, as shown below:
@use_discretization(FourierSeries, dx)
def my_awesome_func(u_field):
    # u_field arrives already wrapped as a FourierSeries
    return some_operator(u_field)
Here, the use_discretization decorator takes care of packing and unpacking the fields:
def use_discretization(discr_class, dx):
    def _decorator(func):
        def wrapper(u):
            # Declare fields, using the discretization and spacing
            # passed to the decorator
            N = u.shape
            u_field = discr_class(u, Domain(N, dx))
            # Perform the desired operation using jaxdf
            v_field = func(u_field)
            # Return a simple jax array
            return v_field.on_grid
        return wrapper
    return _decorator
Potential issues and things to work out
dx in this example
OnGrid fields?
Describe the bug
The reported error is "ValueError: diag input must be 1d or 2d", and the sound map figure is not produced correctly by the code in the 22nd cell.
The 16th cell is also missing an import (from jaxdf.operators import gradient, diag_jacobian, sum_over_dims).
To Reproduce
Steps to reproduce the behavior:
Expected behavior
No bug.