
jaxdf's People

Contributors

astanziola, btreeby, dependabot[bot]

jaxdf's Issues

MPI FiniteDifferences

It would be nice to have a new FiniteDifferences-like class based on mpi4jax, to enable HPC applications.

A good milestone could be to reproduce the shallow-water example from mpi4jax (original code here).
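The key ingredient of such a class would be a halo exchange: each process owns a chunk of the grid and needs one boundary cell from each neighbour before it can apply a stencil locally. The sketch below only illustrates that idea. It simulates two "ranks" inside a single process with plain jax.numpy and uses no MPI at all; in a real implementation the two halo lines would become MPI send/receive calls (for example through mpi4jax, whose API is not shown here). The function names are hypothetical.

```python
import jax.numpy as jnp


def central_diff_global(u, dx):
    # Reference: periodic central difference on the full (undistributed) array
    return (jnp.roll(u, -1) - jnp.roll(u, 1)) / (2 * dx)


def central_diff_distributed(chunks, dx):
    # Each chunk plays the role of the sub-domain owned by one MPI rank
    n = len(chunks)
    out = []
    for rank, local in enumerate(chunks):
        # "Halo exchange": fetch one boundary cell from each periodic neighbour.
        # With real MPI these two lines would become send/receive calls.
        left_halo = chunks[(rank - 1) % n][-1:]
        right_halo = chunks[(rank + 1) % n][:1]
        padded = jnp.concatenate([left_halo, local, right_halo])
        # The local stencil only needs the padded chunk
        out.append((padded[2:] - padded[:-2]) / (2 * dx))
    return jnp.concatenate(out)


dx = 2 * jnp.pi / 16
u = jnp.sin(jnp.arange(16) * dx)
chunks = jnp.split(u, 2)  # two simulated ranks
du_distributed = central_diff_distributed(chunks, dx)
du_global = central_diff_global(u, dx)
```

Because the padded chunk is all a rank needs, the per-rank computation stays a pure JAX function, which is what would make it compatible with jit.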

API for Forwards, Backwards, Central Finite Difference

I would like to be able to control the finite difference scheme used, i.e. forward, backward or central. Depending on the PDE, we normally pick a specific scheme, e.g. backward differences for advection and central differences for diffusion.


Working Demo

I have a working colab notebook to get a feeling for what I mean. See it here.


Proposed Solution

I don't have a full solution, but I think this should be specified somewhere in the params PyTree (just like the accuracy, order, step size, etc.).

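import equinox as eqx

# Proposed API: DiscretizationScheme and gradient are placeholders
u = DiscretizationScheme(u_init, domain)

class Params(eqx.Module):
    method: str = eqx.field(static=True)     # "forward", "backward" or "central"
    stagger: tuple = eqx.field(static=True)  # e.g. (0,)
    accuracy: int = eqx.field(static=True)

params = Params(method="central", stagger=(0,), accuracy=2)

u_grad = gradient(u=u, params=params)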

Another possible solution: use the FiniteDiffX package as a backend for generating the coefficients and kernel when they are not specified. I recently contributed a feature there that allows specifying the FD scheme.

A last solution: just create a custom operator that does everything described above. The "custom equation of motion" section has an example that does exactly what I want.
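To make the request concrete, here is a minimal sketch of selecting the scheme by name, assuming a periodic grid and the standard low-order first-derivative stencils. It is plain JAX, not jaxdf's API; STENCILS and fd_derivative are hypothetical names, and in jaxdf the method string would presumably live in the params PyTree proposed above.

```python
import jax.numpy as jnp

# First-derivative stencils as {offset: coefficient}; the result is divided by dx.
# Standard low-order forward, backward and central schemes.
STENCILS = {
    "forward": {0: -1.0, 1: 1.0},    # (u[i+1] - u[i]) / dx
    "backward": {-1: -1.0, 0: 1.0},  # (u[i] - u[i-1]) / dx
    "central": {-1: -0.5, 1: 0.5},   # (u[i+1] - u[i-1]) / (2 dx)
}


def fd_derivative(u, dx, axis=0, method="central"):
    """First derivative along `axis` on a periodic grid, with a chosen scheme."""
    stencil = STENCILS[method]
    # u[i + offset] is obtained by rolling the array by -offset along `axis`
    return sum(c * jnp.roll(u, -k, axis=axis) for k, c in stencil.items()) / dx
```

With this layout, an upwind advection term would call fd_derivative(u, dx, method="backward") and a diffusion term could use method="central", which matches the per-PDE choice described in the issue.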

Consider moving to equinox modules

Following the suggestion in #110, it probably makes sense to allow custom backends in general or, if that is too complicated, to adopt equinox and make it the default backend.

Running the example in the README does not work.

Describe the bug
When I run the example given in the README, I get the following error:

SignatureError: The argument 'params' must be a keyword argument in . Example: def evaluate(x, *, params): ...

To Reproduce
Steps to reproduce the behavior:

  1. Execute the code example given in the readme

Expected behavior
No error; the gradient is computed at the end.

Desktop (please complete the following information):

  • OS: macOS Ventura 13.3.1(a)
  • Version: Python 3.11.2


Code up some examples

Would be nice to have some examples of using jaxdf for a few simulation applications, for example:

  • Non-linear heat equation
  • Electrical Impedance Tomography forward model
  • Electrical Impedance Tomography solved using some iterative method
  • Reaction-diffusion equations
  • Fluid dynamics
  • ...
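As a starting point for the first item, a non-linear heat equation u_t = d/dx (k(u) du/dx) can be stepped explicitly in a few lines of plain JAX. This is only a sketch: it does not use jaxdf, assumes a periodic 1D grid, and the conductivity k(u) = 1 + u^2 is a hypothetical choice made for illustration. Writing the update in flux form keeps the total "mass" sum(u) conserved up to rounding.

```python
import jax
import jax.numpy as jnp


def conductivity(u):
    # Hypothetical non-linear conductivity, chosen only for illustration
    return 1.0 + u**2


@jax.jit
def heat_step(u, dx, dt):
    """One explicit Euler step of u_t = d/dx (k(u) du/dx) on a periodic 1D grid."""
    u_right = jnp.roll(u, -1)
    # Flux at the faces i+1/2, with k evaluated at the face average
    flux = conductivity(0.5 * (u + u_right)) * (u_right - u) / dx
    # Divergence of the flux: (F_{i+1/2} - F_{i-1/2}) / dx
    div = (flux - jnp.roll(flux, 1)) / dx
    return u + dt * div


x = jnp.linspace(0.0, 1.0, 64, endpoint=False)
dx = 1.0 / 64
# Below the explicit limit dx^2 / (2 k_max), assuming |u| <= 1 so that k_max = 2
dt = 0.1 * dx**2
u0 = jnp.exp(-100.0 * (x - 0.5) ** 2)
u = u0
for _ in range(100):
    u = heat_step(u, dx, dt)
```

A jaxdf version of this example would replace the hand-written rolls with jaxdf operators, which is what would make it useful as documentation.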

Wrong results in the paper example

Describe the bug
Running the paper example with the latest version of JAX (0.4.9) produces high-frequency oscillations in the field that were not present with earlier versions.

To Reproduce
Steps to reproduce the behavior:

  1. Run the example_1_paper notebook
  2. Plot the solution field of the Helmholtz equation

Expected behavior
These oscillations should not be there.

Move `Domain.dx` to `OnGrid.dx`

The .dx attribute only makes sense for fields defined on a grid, and it can be uniquely determined from the shape of OnGrid.params and the size of the domain.
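A minimal sketch of what this could look like, assuming dx = L / N along each axis (the periodic, FourierSeries-style convention; a node-based grid would use L / (N - 1) instead). RectangularDomain follows the proposal below, while OnGridField is a hypothetical stand-in for jaxdf's OnGrid; neither is the current jaxdf API.

```python
from dataclasses import dataclass
from typing import Tuple

import jax.numpy as jnp


@dataclass(frozen=True)
class RectangularDomain:
    # Physical size of the domain along each axis
    L: Tuple[float, ...]

    @classmethod
    def from_grid(cls, N, dx):
        # Alternative constructor from a grid shape and spacing
        return cls(tuple(n * d for n, d in zip(N, dx)))


@dataclass(frozen=True)
class OnGridField:
    params: jnp.ndarray  # grid values, shape (*N,) or (*N, channels)
    domain: RectangularDomain

    @property
    def dx(self):
        # Spacing derived from the domain size and the number of grid points;
        # zip stops at len(L), so a trailing channel axis is ignored
        return tuple(l / n for l, n in zip(self.domain.L, self.params.shape))
```

With this split, the domain only stores geometry, and two fields on the same domain but with different resolutions automatically get different dx values.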

It is much more natural to define the domain as

domain = Domain(size=L)  # L: tuple with the physical size along each axis

# Or, potentially
domain = Domain.from_grid(N, dx)

It would be even better to be able to define domains that are not rectangular, for example

class Domain:  # Abstract domain
  pass

class RectangularDomain(Domain):
  L: tuple  # size along each axis

class SphericalDomain(Domain):
  R: float  # radius

This would then allow non-standard discretizations to be defined. For example, Continuous can work on arbitrarily shaped domains, FourierSeries probably only makes sense on a RectangularDomain, while something like a SphericalFourierSeries (see for example s2fft) could be implemented on a SphericalDomain.

This is clearly a breaking change.

Simple "Difference" Operator

Is there a native way to apply a finite difference operator to a multidimensional field along a single axis, i.e. $\partial_x \vec{\boldsymbol{u}}$, $\partial_y \vec{\boldsymbol{u}}$, $\partial_z \vec{\boldsymbol{u}}$, and so on?

Using the current API, I don't see a way to specify the scheme. It might be achievable with a combination of the stagger options, but I could not find a way to do it. There is a gradient operator, but often we just need a simple difference operator along an axis of our choice.

Note: this might be a consequence of using convolutional operators for the FD scheme, whereas I normally think of differences in terms of slicing.


Demo

I have a demo colab notebook that showcases what I mean. The equation of motion there is a 1D problem; if it were 2D, this API would not work.


Proposed Solution

No specific solution, but it would be helpful to have a simple API for this: many models, such as the shallow-water and quasi-geostrophic models, need it because they contain many advection terms.

# current API
u_grad = gradient(u)
# on_grid is a property; the gradient components are on the last axis
du_dx = u_grad.replace_params(u_grad.on_grid[..., 0:1])

# preferred API (hypothetical)
du_dx = difference(u, axis=0)

Another solution is just to write a custom operator for the difference scheme; the colab notebook linked above has an example of this. This is also related to issues #127 and #125.
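For reference, the slicing-based operator described above can be sketched in plain JAX. This is not a jaxdf operator and difference is a hypothetical name; it computes (u[i+1] - u[i]) / dx along one chosen axis, so the result lives on the staggered grid and is one point shorter along that axis. In 1D it gives the same numbers as a "valid" convolution with the kernel [1, -1], which is the convolutional view mentioned in the note above.

```python
import jax.numpy as jnp


def difference(u, axis=0, dx=1.0):
    """Hypothetical forward difference along a single axis, implemented by slicing."""
    # Index tuples selecting u[1:] and u[:-1] along `axis` only
    upper = [slice(None)] * u.ndim
    lower = [slice(None)] * u.ndim
    upper[axis] = slice(1, None)
    lower[axis] = slice(None, -1)
    return (u[tuple(upper)] - u[tuple(lower)]) / dx


u = jnp.arange(12.0).reshape(3, 4)
du_dx = difference(u, axis=0)  # shape (2, 4)
du_dy = difference(u, axis=1)  # shape (3, 3)

# Convolutional view in 1D: kernel [1, -1] in "valid" mode
v = jnp.array([0.0, 1.0, 4.0, 9.0])
via_conv = jnp.convolve(v, jnp.array([1.0, -1.0]), mode="valid")
```

Whether this maps onto jaxdf's FiniteDifferences params (for example through the stagger option) is an open question in this issue; the sketch only shows the arithmetic.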

Throw a better error when `params` is missing from an operator

Is your feature request related to a problem? Please describe.
The error currently returned to the user when params is missing from an operator is not very clear:

SignatureError: The argument 'params' must be a keyword argument in . Example: def evaluate(x, *, params)

Describe the solution you'd like
Make it clear to the user that params must be defined for operators, and give an example of how to do it. At the moment, this relies on plum's signature error handling.
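One possible approach, sketched here under the assumption that the check can run before the function is handed to plum: inspect the operator's signature up front and raise an explicit error with a copy-pasteable example. check_operator_signature is a hypothetical helper, not part of jaxdf; only the standard-library inspect module is used.

```python
import inspect


def check_operator_signature(func):
    """Hypothetical helper: fail early, with a clear message, if `params` is missing."""
    sig = inspect.signature(func)
    param = sig.parameters.get("params")
    if param is None or param.kind is not inspect.Parameter.KEYWORD_ONLY:
        raise TypeError(
            f"Operator '{func.__name__}' must accept 'params' as a keyword-only "
            "argument, for example:\n\n"
            f"    def {func.__name__}(x, *, params=None):\n"
            "        ...\n"
        )
    return func
```

Because it returns the function unchanged, a helper like this could be called at the start of an operator decorator without affecting dispatch.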

Make a wrapper to hide jaxdf computations

One feature that emerged from the chat with @jejjohnson is the ability to use jaxdf fields without exposing them to the user, or at least without requiring the user to work with them explicitly.

A common pattern for achieving this is given by the following code:

from jax.typing import ArrayLike
from jaxdf import FourierSeries
from jaxdf.geometry import Domain

def my_awesome_func(u: ArrayLike):
  # Declare fields
  N = u.shape
  dx = [0.1,] * len(N)
  u_field = FourierSeries(u, Domain(N, dx))

  # Perform the desired operation using jaxdf (some_operator is a placeholder)
  v_field = some_operator(u_field)

  # Return a simple jax array
  return v_field.on_grid

To simplify the syntax and achieve a cleaner implementation, this pattern can be encapsulated in a decorator, as shown below:

@use_discretization(FourierSeries, dx)
def my_awesome_func(u_field):
  # The decorator has already wrapped the input array into a field
  return some_operator(u_field)

Here, the use_discretization decorator takes care of packing and unpacking the fields:

import functools
from jaxdf.geometry import Domain

def use_discretization(discr_class, dx):
  def _decorator(func):
    @functools.wraps(func)
    def wrapper(u):
      # Declare fields using the class and grid spacing given to the decorator
      N = u.shape
      u_field = discr_class(u, Domain(N, dx))

      # Perform the desired operation using jaxdf
      v_field = func(u_field)

      # Return a simple jax array
      return v_field.on_grid

    return wrapper
  return _decorator

Potential issues and things to work out

  • How to deal with multiple input fields
  • How to pass generic parameters, i.e. generalize dx in this example
  • Does this only make sense for OnGrid fields?
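For the first two points, one option is to wrap every positional array argument and forward arbitrary keyword arguments to the discretization constructor. The sketch below does this with a hypothetical ToyField class standing in for a jaxdf discretization, so that it runs without jaxdf; with the real classes, which take (params, domain), the keyword arguments would have to carry the domain information instead.

```python
import functools
from dataclasses import dataclass
from typing import Tuple

import jax.numpy as jnp


@dataclass
class ToyField:
    # Hypothetical stand-in for a jaxdf discretization such as FourierSeries
    params: jnp.ndarray
    dx: Tuple[float, ...]

    @property
    def on_grid(self):
        return self.params


def use_discretization(discr_class, **discr_kwargs):
    """Wrap every array argument into `discr_class`, unwrap the returned field(s)."""
    def _decorator(func):
        @functools.wraps(func)
        def wrapper(*arrays):
            fields = [discr_class(a, **discr_kwargs) for a in arrays]
            out = func(*fields)
            # Support functions that return several fields
            if isinstance(out, tuple):
                return tuple(f.on_grid for f in out)
            return out.on_grid
        return wrapper
    return _decorator


@use_discretization(ToyField, dx=(0.1,))
def add_fields(u, v):
    # Works on fields and returns a field; the decorator hides both conversions
    return ToyField(u.params + v.params, u.dx)


result = add_fields(jnp.ones(3), jnp.arange(3.0))
```

This still assumes all inputs share the same discretization parameters; inputs on different grids would need per-argument settings.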

There are bugs in the tutorial example

Describe the bug
Running the tutorial fails with "ValueError: diag input must be 1d or 2d", and the sound map plotted by the code in the 22nd cell is wrong.
The 16th cell is also missing an import: from jaxdf.operators import gradient, diag_jacobian, sum_over_dims.
To Reproduce
Steps to reproduce the behavior:

  1. Follow the tutorial at https://ucl-bug.github.io/jaxdf/notebooks/helmholtz_pinn/

Expected behavior
The tutorial runs without errors.
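The error text matches the check in jax.numpy.diag, which only accepts 1-D or 2-D input. The report does not say which call triggers it, so the following is an assumption: if a batch of Jacobians (one per collocation point) reaches jnp.diag, either of the two options below extracts the per-point diagonals instead. The array shapes here are illustrative only.

```python
import jax
import jax.numpy as jnp

# A batch of 5 Jacobians, each 3x3 (e.g. one per collocation point)
jac = jnp.arange(5 * 3 * 3, dtype=jnp.float32).reshape(5, 3, 3)

# jnp.diag(jac) raises ValueError here, because jac is 3-D.
# Option 1: take the diagonal over the last two axes
diag_a = jnp.diagonal(jac, axis1=-2, axis2=-1)  # shape (5, 3)

# Option 2: map jnp.diag over the batch axis
diag_b = jax.vmap(jnp.diag)(jac)  # shape (5, 3)
```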
