
pmwd's People

Contributors

dsjamieson, eelregit, modichirag, yucheng-zhang


pmwd's Issues

odeint is slow

I was experimenting with some timing tests and found that using odeint to calculate the growth functions is quite slow.
I tried hacking around it by replacing odeint with a fixed-step RK4 integration inside the growth function itself, which seems to be much faster.

    ode_jit = jit(ode)

    def rk4_ode_jit(carry, t):
        # One classical RK4 step from t_prev to t for the growth ODE.
        y, t_prev = carry
        h = t - t_prev
        k1 = ode_jit(y, t_prev, cosmo)
        k2 = ode_jit(y + h * k1 / 2, t_prev + h / 2, cosmo)
        k3 = ode_jit(y + h * k2 / 2, t_prev + h / 2, cosmo)
        k4 = ode_jit(y + h * k3, t, cosmo)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (y, t), y

    (yf, _), G = scan(rk4_ode_jit, (G_ic, lna[0]), lna)

Then I ran timing tests for a 64^3 simulation, passing the cosmology parameters and initial modes as inputs, and timed different outputs (Boltzmann solve only vs. Boltzmann + LPT).

@jit
def simulate_boltz(modes, omegam, conf):
    '''Evaluate growth & transfer functions with odeint.'''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate_boltz_rk4(modes, omegam, conf):
    '''Evaluate growth & transfer functions with the custom RK4.'''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate(modes, omegam, conf):
    '''Run LPT simulation, evaluating growth & transfer functions with odeint.'''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

@jit
def simulate_rk4(modes, omegam, conf):
    '''Run LPT simulation, evaluating growth & transfer functions with the custom RK4.'''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

@jit
def simulate_nbody(modes, cosmo):
    '''Run LPT simulation without evaluating growth & transfer functions.'''
    ptcl, obsvbl = lpt(modes, cosmo)
    conf = cosmo.conf
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

The time taken for each of these is:

Time taken for boltzmann: 0.5971660375595093
Time taken for boltzmann rk4: 0.007928729057312012
Time taken for LPT: 0.0041596412658691405
Time taken for simulation (Boltzmann + LPT): 0.463437557220459
Time taken for simulation rk4 (Boltzmann + LPT): 0.04284675121307373

RK4 seems to be much faster than odeint for generating the growth functions.

If what I am doing in running these simulations is sensible and the timing numbers paint an accurate picture, then we should figure out a better (JAX-friendly) way to code this.

I have attached the full script as a txt file (copy it into the pmwd/pmwd folder, rename it to .py, and it should run):
test_growth.txt
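One caveat on the timings above (an aside, not from the attached script): JAX dispatches asynchronously and compiles on the first call, so a naive wall-clock measurement of a jitted function can fold in compilation and dispatch time. A minimal sketch of a more robust measurement, assuming the simulate_* functions above:

import time
import jax

def time_fn(fn, *args, n=10):
    # Warm-up call: triggers JIT compilation, which should not be timed.
    jax.block_until_ready(fn(*args))
    # Average over n runs, blocking on the asynchronous dispatch each time.
    t0 = time.perf_counter()
    for _ in range(n):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - t0) / n

# e.g. time_fn(simulate_boltz, modes, omegam, conf)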

A_s vs sigma_8 parameterization

This is tangentially related to Issue 6.

We decided to parameterize cosmology in terms of $A_s$ instead of $\sigma_8$, on the grounds that $A_s$ is more independent of the other cosmological parameters than $\sigma_8$, so sampling the former might be easier during inference.
However, in my experiments I am finding otherwise.

I set up a toy problem of sampling $A_s$, $\Omega_m$, and the initial conditions for a mock dark matter density field: L=100, N=32, a 3-step PM forward model, with shot noise.

Here are the posteriors in the $A_s$-$\Omega_m$ plane, and the corresponding posterior in the $\sigma_8$-$\Omega_m$ plane, where I estimated $\sigma_8$ for the given posterior samples. This is 8000 samples without thinning.

[image: posterior contours in the $A_s$-$\Omega_m$ and $\sigma_8$-$\Omega_m$ planes]

The posterior seems to be much harder to sample in the $A_s$-$\Omega_m$ parameterization. In fact, most of my chains are not even burning in.

The covariance matrix for $A_s$-$\Omega_m$ is

[[ 0.00113 -0.01357]
 [-0.01357  0.171  ]]

with condition number ~2600, and the covariance matrix for $\sigma_8$-$\Omega_m$ is

[[ 0.00113 -0.00046]
 [-0.00046  0.00032]]

with condition number ~12.
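(For reference, a minimal snippet for computing these condition numbers from posterior samples; the samples array here is hypothetical:)

import jax.numpy as jnp

# samples: hypothetical (n_samples, 2) array of (A_s, Omega_m) posterior draws
cov = jnp.cov(samples, rowvar=False)   # 2x2 sample covariance
cond = jnp.linalg.cond(cov)            # ratio of largest to smallest singular value
print(cov, cond)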

Now one could scale $A_s$ down by a factor of 10 to bring it to the same order as $\Omega_m$; that might help a little.
But even then the posterior shows a multiplicative degeneracy (the banana shape), which makes sense if you think about it and also look at the code:

    Plin = (
        0.32 * cosmo.A_s * cosmo.k_pivot * _safe_power(k / cosmo.k_pivot, cosmo.n_s)
        * (jnp.pi * (conf.c / conf.H_0)**2 / cosmo.Omega_m * T)**2
        * D**2
    )

i.e. $P_{lin} \propto A_s / \Omega_m$, so the two parameters enter multiplicatively.
The degeneracy will not completely go away with the $\sigma_8$ parameterization, but the posterior suggests $\sigma_8$ might be better constrained.

Finally, though this example was for dark matter, I suspect the same will hold for all LSS tracers (e.g. galaxies); weak lensing has a different dependence, and I don't know yet which parameterization is better there.

For sampling with galaxy clustering, though, I now suspect the $\sigma_8$ parameterization might be better. Is there any way we can keep both and let the user decide?
Thoughts, opinions, comments? Does this make sense?

Outer JIT compilation time could be optimized

In this example:
https://gist.github.com/EiffL/8e46d261e5d52cd28ca81e233fef9b04

The first evaluation of the model takes 3 minutes to run, but the second run takes just a few seconds.

@modichirag has also been able to check that the compilation time is a function of the number of steps. This would indicate that the code is building an overly complex computational graph that explicitly includes each step of the nbody integration.

I suspect this is due to the Python for loop in the nbody function. Things would probably improve a lot if it were replaced with a lax.scan; see the sketch below.
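For illustration (a toy stand-in, not the actual pmwd nbody code), here is the general pattern: a Python loop is retraced into the graph once per step, while lax.scan traces the step function only once:

import jax
import jax.numpy as jnp

N_STEPS = 100

def step(state, _):
    # Stand-in for one N-body time step (the real one would kick/drift particles).
    return state * 0.99 + 1.0, None

@jax.jit
def nbody_unrolled(state):
    # The loop body is traced N_STEPS times: the graph, and hence compile time,
    # grows linearly with the number of steps.
    for _ in range(N_STEPS):
        state, _ = step(state, None)
    return state

@jax.jit
def nbody_scanned(state):
    # lax.scan traces `step` once; compile time stays flat as N_STEPS grows.
    state, _ = jax.lax.scan(step, state, None, length=N_STEPS)
    return state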

A_s - sigma_8 utility function

Should we include a utility function to switch back and forth between $A_s$ and $\sigma_8$, maybe at the configuration level?

Often we know the cosmology in terms of $\sigma_8$ rather than $A_s$, in which case it would be good to let the user specify the config in terms of $\sigma_8$ while the conversion happens under the hood.

In a similar vein, once we generate samples (during inference) in terms of $A_s$, we should have a way to convert them to $\sigma_8$. A sketch of one direction of this conversion follows.
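As a starting point, here is a minimal sketch of the $A_s \to \sigma_8$ direction, integrating the linear power spectrum against a top-hat window of radius 8 Mpc/h. The linear_power import and signature, and the Mpc length units, are assumptions about the pmwd API; adjust to the actual one:

import jax.numpy as jnp
from pmwd import linear_power   # assumed: linear_power(k, a, cosmo, conf)

def sigma8_from_As(cosmo, conf):
    # sigma_8^2 = 1/(2 pi^2) * int dk k^2 P_lin(k) W^2(k R), with R = 8 Mpc/h.
    R = 8.0 / cosmo.h                        # top-hat radius, assuming lengths in Mpc
    k = jnp.logspace(-4.0, 2.0, 1024)        # 1/Mpc grid, wide enough to converge
    x = k * R
    W = 3.0 * (jnp.sin(x) - x * jnp.cos(x)) / x**3   # Fourier-space top-hat window
    P = linear_power(k, 1.0, cosmo, conf)    # linear P(k) at a = 1
    integrand = k**3 * P * W**2 / (2.0 * jnp.pi**2)  # d(sigma^2) / d(ln k)
    lnk = jnp.log(k)
    # Trapezoid rule in ln k.
    return jnp.sqrt(jnp.sum(0.5 * (integrand[1:] + integrand[:-1]) * jnp.diff(lnk)))

Since $P_{lin} \propto A_s$, i.e. $\sigma_8^2 \propto A_s$, the inverse direction is a one-line rescaling, A_s -> A_s * (sigma8_target / sigma8_from_As(cosmo, conf))**2, which is presumably what Cosmology.from_sigma8 does under the hood.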

Upload videos made by Yucheng

Particles align on the initial grid after evolving forward and then backward in time.

time_evolution.mp4

We optimize the initial conditions by gradient descent to make an interesting pattern after projection.

pmwd_optim.mp4

Multi-host distribution

In fantastic news, after over a year of waiting and checking every few months whether it was working yet, it is finally possible to instantiate a distributed XLA runtime in JAX, which means native access to NCCL collectives and composable parallelization with pmap and xmap!

Demo of how to allocate 16 GPUs across 4 nodes on Perlmutter here: https://github.com/EiffL/jax-gpu-cluster-demo

I'll be testing these things out and documenting my findings in this issue. It may not be directly useful at first, but at some point down the line we want to be able to run very large sims easily.
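For reference, the basic initialization pattern looks roughly like this (a sketch assuming a SLURM-style launcher; COORD_ADDR is a hypothetical environment variable, while the SLURM_* variables come from the job scheduler):

import os
import jax

# Must run before any other JAX call; rank 0 hosts the coordinator.
jax.distributed.initialize(
    coordinator_address=os.environ['COORD_ADDR'],   # e.g. "nid001:1234"
    num_processes=int(os.environ['SLURM_NTASKS']),
    process_id=int(os.environ['SLURM_PROCID']),
)

# After initialization, every process sees all devices across the cluster.
print(jax.process_index(), jax.local_device_count(), jax.device_count())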

ValueError when using @jax.jit decorator for objective function with pmwd

Hi, I am writing a sampling code using pmwd. The objective function in my code (actually the negative log probability) is called frequently, so I attempted to add @jax.jit to speed it up. However, I always get an error like ValueError: invalid literal for int() with base 10: 'int16[8,3]' when executing the nbody function.
Here is an illustrative example:

from pmwd import (
    Configuration, Cosmology,
    SimpleLCDM,
    boltzmann,
    white_noise, linear_modes,
    lpt,
    nbody,
    scatter,
) 

import jax
import jax.numpy as jnp
from jax import jit

def model(modes, cosmo, conf):
    modes = linear_modes(modes, cosmo, conf, None, False)
    ptcl, obsvbl = lpt(modes, cosmo, conf)
    ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf)
    dens = scatter(ptcl, conf)
    return dens

# @jit  # uncomment this and the error occurs!
def obj(modes, cosmo, conf):
    dens = model(modes, cosmo, conf)
    return jnp.sum(dens)

conf = Configuration(ptcl_spacing=4, ptcl_grid_shape=(2,)*3, lpt_order=1, float_dtype=jnp.float64, \
                     a_start=0.01, a_nbody_maxstep=1)

p0 = jnp.array([0.3, 0.8])
cosmo = Cosmology.from_sigma8(conf, Omega_m=0.3, sigma8=0.8, n_s=0.96, Omega_b=0.05, h=0.7)
cosmo = boltzmann(cosmo, conf)

obj_func = lambda z: obj(z, cosmo, conf)

vng = jax.value_and_grad(obj_func, argnums=(0))
data_ = jnp.array([[[1,1], [1,1]], [[1, 1], [1, 1]]], dtype=jnp.float64)

vng(data_)

And I got the following error for the script above when enabling @jax.jit:

(jax-test) [user@cluster ~]$ python jit_issue.py 
Traceback (most recent call last):
  File "/home/user/jit_issue.py", line 39, in <module>
    vng(data_)
  File "/home/user/jit_issue.py", line 34, in <lambda>
    obj_func = lambda z: obj(z, cosmo, conf)
  File "/home/user/jit_issue.py", line 24, in obj
    dens = model(modes, cosmo, conf)
  File "/home/user/jit_issue.py", line 18, in model
    ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: invalid literal for int() with base 10: 'int16[8,3]'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/jit_issue.py", line 39, in <module>
    vng(data_)
  File "/home/user/jit_issue.py", line 34, in <lambda>
    obj_func = lambda z: obj(z, cosmo, conf)
  File "/home/user/work/pmwd/pmwd/tree_util.py", line 97, in tree_unflatten
    return cls(**dict(zip(children_names, children)),
  File "<string>", line 9, in __init__
  File "/home/user/work/pmwd/pmwd/particles.py", line 72, in __post_init__
    else jnp.asarray(value, dtype=dtype))
  File "/home/user/.conda/envs/jax-test/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2206, in asarray
    return array(a, dtype=dtype, copy=False, order=order)  # type: ignore
  File "/home/user/.conda/envs/jax-test/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2152, in array
    out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)  # type: ignore[arg-type]
ValueError: invalid literal for int() with base 10: 'int16[8,3]'

I am using jax-0.4.23 [cpu] & pmwd-0.1.dev122+g1e1c634.d20240303 (GitHub commit 1e1c634). Is this expected behavior?


Mar 6th update: I attempted to run the script above in an environment with a GPU installation of jax, and the error persists, so this does not appear to be a CPU-only problem.

the initial step size in jax odeint could be nan

The odeint function in jax.experimental.ode determines the initial step size internally. However, the computed value can be nan when the initial derivative fun(y0, t0) is much smaller than the initial value y0. This is caused by the internal heuristic, which evaluates the function a trial step of size ~y0/fun(y0, t0) away; that point can be far beyond the valid range of the function and thus return nan.
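A minimal reproduction with a toy ODE (not the pmwd growth equation), using odeint's default tolerances:

import jax.numpy as jnp
from jax.experimental.ode import odeint

def fun(y, t):
    # Only valid for t <= 2; the derivative is tiny relative to y0 = 1.
    return 1e-8 * jnp.sqrt(2.0 - t)

y0 = jnp.array(1.0)
t = jnp.array([0.0, 1.0])
# The internal heuristic picks a trial step ~0.01 * |y0| / |f0| ~ 1e6,
# evaluates fun where sqrt(2 - t) is undefined, and the nan propagates.
print(odeint(fun, y0, t))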

MLP growth function

With Epod, we now have MLP emulators for the growth function that are quite fast.
Do we want to include them in the official pmwd code, and if so, how? Ideally we would keep both options, the MLP and odeint as currently implemented, with a way of switching back and forth between them.

I have a sample version that I am currently using; I will create a separate branch to push it as an example, and we can discuss it further from there. A generic sketch of the idea is below.
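For concreteness, here is what such an emulator and switch could look like in pure jax.numpy. The architecture, the growth_mlp_params config field, and the growth_odeint helper are all illustrative placeholders, not the Epod implementation:

import jax.numpy as jnp

def mlp_growth(params, omega_m, a):
    # Tiny MLP mapping (Omega_m, a) -> growth factor D(a); params is a list of
    # pretrained (weights, biases) pairs (illustrative only).
    x = jnp.stack([omega_m, a])
    for W, b in params[:-1]:
        x = jnp.tanh(W @ x + b)
    W, b = params[-1]
    return (W @ x + b)[0]

def growth(cosmo, conf, a, method='odeint'):
    # Hypothetical switch between the current odeint-based growth and the MLP.
    if method == 'mlp':
        return mlp_growth(conf.growth_mlp_params, cosmo.Omega_m, a)
    return growth_odeint(cosmo, conf, a)  # stand-in for the existing implementation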

Python 3.7 support?

I assume there is a reason for not enabling Python 3.7 in setup.py ^^ But Colab is currently on Python 3.7, and when I tried to run the example code there, it crashed with this error:

/usr/local/lib/python3.7/dist-packages/pmwd/boltzmann.py in growth_integ(cosmo, conf)
    199     ), axis=1)
    200 
--> 201     return cosmo.replace(growth=growth)
    202 
    203 

/usr/local/lib/python3.7/dist-packages/pmwd/tree_util.py in replace(self, **changes)
    127     def replace(self, **changes):
    128         """Create a new object of the same type, replacing fields with changes."""
--> 129         return dataclasses.replace(self, **changes)
    130 
    131     cls.replace = replace

/usr/lib/python3.7/dataclasses.py in replace(*args, **changes)
   1270         if f.name not in changes:
   1271             if f._field_type is _FIELD_INITVAR:
-> 1272                 raise ValueError(f"InitVar {f.name!r} "
   1273                                  'must be specified with replace()')
   1274             changes[f.name] = getattr(obj, f.name)

ValueError: InitVar 'Omega_k_fixed' must be specified with replace()

I assume there is some new functionality in Python 3.8 that makes this work but not in 3.7? If that's so, I'd be happy to try to implement a workaround. Being able to run things on Colab is super important.

Jax bug in linear_modes

On the master branch, this line causes the following error:

  File "/home/mattho/git/ltu-cmass/cmass/nbody/pmwd.py", line 74, in run_density
    ic = linear_modes(wn, pmcosmo, pmconf)
ValueError: the `static_argnums` argument to `jax.checkpoint` / `jax.remat` can only take integer values greater than or equal to `-len(args)` and less than `len(args)`, but got (4,)

It seems @eelregit pointed this out as a jax bug, but it hasn't been resolved yet.

Could a temporary workaround be pushed to master? This is currently breaking the master branch.
