eelregit / pmwd
Differentiable Cosmological Forward Model
License: BSD 3-Clause "New" or "Revised" License
I was experimenting with some timing tests and found that using odeint to compute the growth functions is quite slow.
As a hack, I replaced it with an RK4 integration inside the growth function itself, which seems to be much faster:
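# Inside the growth-function code, replacing the odeint call;
# ode, cosmo, G_ic and lna come from the surrounding growth-function code.
from jax import jit
from jax.lax import scan
ode_jit = jit(ode)
def rk4_ode_jit(carry, t):
    # One classic RK4 step from t_prev to t. The first step has h = 0,
    # so G[0] equals the initial condition G_ic.
    y, t_prev = carry
    h = t - t_prev
    k1 = ode_jit(y, t_prev, cosmo)
    k2 = ode_jit(y + h * k1 / 2, t_prev + h / 2, cosmo)
    k3 = ode_jit(y + h * k2 / 2, t_prev + h / 2, cosmo)
    k4 = ode_jit(y + h * k3, t, cosmo)
    y = y + 1.0 / 6.0 * h * (k1 + 2 * k2 + 2 * k3 + k4)
    return (y, t), y
# Integrate over the lna grid, collecting G at every grid point.
(yf, _), G = scan(rk4_ode_jit, (G_ic, lna[0]), lna)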
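For comparison, here is a self-contained sketch of the same fixed-step RK4 + lax.scan idea, applied to a toy decay ODE instead of pmwd's growth ODE. The names rk4_scan and decay and the toy ODE itself are illustrative assumptions, not pmwd code. The ODE is called as ode(y, t, *args), the same argument order jax.experimental.ode.odeint uses, and the result is checked against the exact solution.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan


def rk4_scan(ode, y0, ts, *args):
    """Fixed-step RK4 over the grid ts, returning y at every grid point."""
    def step(carry, t):
        y, t_prev = carry
        h = t - t_prev  # zero on the first step, so ys[0] == y0
        k1 = ode(y, t_prev, *args)
        k2 = ode(y + h * k1 / 2, t_prev + h / 2, *args)
        k3 = ode(y + h * k2 / 2, t_prev + h / 2, *args)
        k4 = ode(y + h * k3, t, *args)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (y, t), y

    _, ys = scan(step, (y0, ts[0]), ts)
    return ys


def decay(y, t, rate):
    # Toy ODE dy/dt = -rate * y, exact solution y0 * exp(-rate * t).
    return -rate * y


ts = jnp.linspace(0.0, 2.0, 65)
# The ODE function is passed as a static (hashable) argument to jit.
ys = jax.jit(rk4_scan, static_argnums=0)(decay, jnp.array(1.0), ts, 1.0)
```

One design caveat: unlike odeint, which adapts its steps to meet rtol/atol, a fixed-step scheme has no error control, so its accuracy depends entirely on how finely the lna grid is spaced. Also, wrapping ode in jit inside an already-jitted function is generally unnecessary, because the outer jit compiles everything together.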
Then I ran timing tests for a 64^3 simulation, in which I pass the cosmological parameters and initial modes as inputs and time the different outputs (Boltzmann solve only vs. Boltzmann + LPT):
# SimpleLCDM, boltzmann, lpt and scatter come from pmwd;
# boltzmann_rk4 is the custom RK4 version of boltzmann described above.
import jax.numpy as jnp
from jax import jit

@jit
def simulate_boltz(modes, omegam, conf):
    '''Evaluate growth & transfer functions with odeint
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate_boltz_rk4(modes, omegam, conf):
    '''Evaluate growth & transfer functions with custom RK4
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    mesh = None
    return mesh, cosmo

@jit
def simulate(modes, omegam, conf):
    '''Run LPT simulation, evaluating growth & transfer functions with odeint
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

@jit
def simulate_rk4(modes, omegam, conf):
    '''Run LPT simulation, evaluating growth & transfer functions with custom RK4
    '''
    cosmo = SimpleLCDM(conf, Omega_m=omegam)
    cosmo = boltzmann_rk4(cosmo)
    ptcl, obsvbl = lpt(modes, cosmo)
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo

@jit
def simulate_nbody(modes, cosmo):
    '''Run LPT simulation without evaluating growth & transfer functions
    '''
    ptcl, obsvbl = lpt(modes, cosmo)
    conf = cosmo.conf
    dens = jnp.zeros(conf.mesh_shape, dtype=conf.float_dtype)
    mesh = scatter(ptcl, dens, 1., conf.cell_size, conf.chunk_size)
    return mesh, cosmo
The time taken for each of these is:
Time taken for boltzmann: 0.5971660375595093
Time taken for boltzmann rk4: 0.007928729057312012
Time taken for LPT: 0.0041596412658691405
Time taken for simulation (Boltzmann + LPT): 0.463437557220459
Time taken for simulation rk4 (Boltzmann + LPT): 0.04284675121307373
RK4 seems to be much faster than odeint for computing the growth functions.
If my way of running the simulations is sensible and the timings give an accurate picture,
should we figure out a better (JAX-friendly) way to code this?
I have attached the full script as a .txt file (copy it into the pmwd/pmwd folder, rename it to .py, and it should run):
test_growth.txt
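Whether the timings are accurate depends on how they were taken, which the post does not show. A common pattern for timing jitted JAX code is sketched below: the first call is discarded because it includes tracing and compilation, and every timed call is followed by jax.block_until_ready because JAX dispatches work asynchronously. The helper time_jitted and the toy function are illustrative stand-ins, not part of pmwd or the attached script.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(f, *args, n_repeat=10):
    """Mean wall-clock seconds per call of f(*args), excluding compilation."""
    jax.block_until_ready(f(*args))  # warm-up call: triggers compilation once
    start = time.perf_counter()
    for _ in range(n_repeat):
        # Wait for the device to finish, not just for the call to return.
        jax.block_until_ready(f(*args))
    return (time.perf_counter() - start) / n_repeat


@jax.jit
def toy(x):
    # Stand-in for simulate_boltz, simulate_rk4, etc.
    return jnp.sum(jnp.sin(x) ** 2)


mean_s = time_jitted(toy, jnp.ones((64, 64, 64)))
```

jax.block_until_ready also accepts pytrees, so it can wrap outputs like (mesh, cosmo) directly, e.g. time_jitted(simulate_rk4, modes, omegam, conf).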
This is tangentially related to Issue 6
We decided to parameterize cosmology in terms of
However, in my experiments I am finding otherwise.
I set up a toy problem of sampling
L=100, N=32, 3-step PM forward model, noise = shot noise
Here are the posteriors in
The posterior seems to be much harder in
Covariance matrix for
[[ 0.00113 -0.01357]
[-0.01357 0.17.1 ]]
with condition number ~2600
and
covariance matrix for
[[ 0.00113 -0.00046]
[-0.00046 0.00032]]
with condition number ~12.
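As a quick arithmetic check of the second quoted condition number (the lower-right entry of the first matrix is not legible in this copy, so it is not reproduced): for a symmetric positive-definite matrix, the 2-norm condition number is the ratio of its largest to smallest eigenvalue, which here comes out to about 11.9, consistent with the ~12 quoted.

```python
import numpy as np

# Second covariance matrix quoted above.
cov = np.array([[0.00113, -0.00046],
                [-0.00046, 0.00032]])

# Ratio of extreme eigenvalues = 2-norm condition number for an SPD matrix.
eigvals = np.linalg.eigvalsh(cov)
cond = eigvals.max() / eigvals.min()
assert np.isclose(cond, np.linalg.cond(cov))
```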
Now one can maybe scale down
But even then, the posterior shows a multiplicative degeneracy (banana shape), which makes sense if you think about it and also look at the code:
Plin = (
    0.32 * cosmo.A_s * cosmo.k_pivot * _safe_power(k / cosmo.k_pivot, cosmo.n_s)
    * (jnp.pi * (conf.c / conf.H_0)**2 / cosmo.Omega_m * T)**2
    * D**2
)
i.e.
It will not completely go away with
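Reading the explicit A_s and Omega_m dependence directly off the Plin expression quoted above gives the scaling below. This holds T and D fixed; both also depend on Omega_m in general, so it only describes the explicit prefactor.

```latex
P_\mathrm{lin}(k) = 0.32\,A_s\,k_\mathrm{pivot}\left(\frac{k}{k_\mathrm{pivot}}\right)^{n_s}
\left(\frac{\pi\,(c/H_0)^2\,T(k)}{\Omega_m}\right)^2 D^2
\;\propto\; \frac{A_s}{\Omega_m^2}\,T^2(k)\,D^2
```

So at fixed T and D, pairs (A_s, Omega_m) along A_s proportional to Omega_m^2 leave this prefactor unchanged, which is one way to see where a multiplicative degeneracy between the two can come from.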
Finally, although this was a dark matter example, I suspect it will hold for all LSS probes (e.g. galaxies). Weak lensing will have a different form, and I don't know yet which parameterization is better there.
But for sampling with galaxy clustering, I am now suspecting
Thoughts, opinions, comments? Does this make sense?
In this example:
https://gist.github.com/EiffL/8e46d261e5d52cd28ca81e233fef9b04
It takes 3 minutes for the first evaluation of the model to run, but just a few seconds for the second run.
@modichirag has also been able to check that the compilation time grows with the number of steps. This would indicate that the code is building an overly complex computational graph that explicitly includes each step of the N-body loop.
I suspect this is due to using a Python for loop in the nbody function. Things would probably improve a lot if it were replaced with a lax.scan.
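A minimal sketch of the suggested change, using a toy harmonic-oscillator kick-drift step as a stand-in for one N-body step (kick_drift, evolve_loop and evolve_scan are illustrative names, not pmwd's nbody code). With the Python loop, jit traces every iteration, so the compiled graph, and the compile time, grows with the number of steps; with lax.scan the step body is traced and compiled once.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan


def kick_drift(state, dt):
    # Toy symplectic-Euler step for a harmonic oscillator.
    x, v = state
    v = v - x * dt
    x = x + v * dt
    return (x, v)


@jax.jit
def evolve_loop(x0, v0, dts):
    # Python loop: unrolled at trace time, one copy of the step per iteration.
    state = (x0, v0)
    for i in range(dts.shape[0]):
        state = kick_drift(state, dts[i])
    return state


@jax.jit
def evolve_scan(x0, v0, dts):
    # lax.scan: a single traced step body, independent of len(dts).
    def body(state, dt):
        return kick_drift(state, dt), None

    state, _ = scan(body, (x0, v0), dts)
    return state


dts = jnp.full(100, 0.01)
x0, v0 = jnp.ones(8), jnp.zeros(8)
xl, vl = evolve_loop(x0, v0, dts)
xs, vs = evolve_scan(x0, v0, dts)
```

Both versions compute the same result; only the size of the traced program differs.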
Should we include a utility function to switch back and forth between
Often it is the case that we know the cosmology in terms of
In a similar vein, once we generate samples (during inference) in terms of
Particles align on the initial grid after evolving forward and then backward in time.
We optimize the initial conditions by gradient descent to make some interesting pattern after projection.
In fantastic news: after over a year of waiting and checking every few months whether it was working yet, it finally looks possible to instantiate a distributed XLA runtime in JAX, which means... native access to NCCL collectives and composable parallelisation with pmap and xmap!!!
Demo of how to allocate 16 GPUs across 4 nodes on Perlmutter here: https://github.com/EiffL/jax-gpu-cluster-demo
I'll be testing these things out and documenting my findings in this issue. It may not be directly useful at first, but at some point down the line we will want to be able to run very large sims easily.
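For reference, a minimal sketch of starting JAX's multi-process runtime under Slurm, the scheduler Perlmutter uses. This is generic JAX usage, not taken from the linked demo, and init_multihost_if_slurm is a hypothetical helper name. Per the JAX multi-process documentation, jax.distributed.initialize() can be called with no arguments in a Slurm environment and infers the coordinator address and process ids from it; checking SLURM_PROCID as the guard is an assumption of this sketch.

```python
import os

import jax


def init_multihost_if_slurm():
    """Start JAX's distributed runtime when running as a Slurm job step.

    Returns True if the runtime was initialized, False outside Slurm.
    """
    if "SLURM_PROCID" not in os.environ:
        return False
    # No arguments: coordinator address, process count and process id
    # are detected from the Slurm environment.
    jax.distributed.initialize()
    return True
```

After it returns True, jax.device_count() counts the GPUs across all processes, while jax.local_device_count() counts only those attached to the current process.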
Hi, I am writing a sampling code using pmwd. The objective function in my code (actually a negative log probability) is called frequently, so I attempted to add @jax.jit to speed it up. However, I always get an error like ValueError: invalid literal for int() with base 10: 'int16[8,3]' when executing the nbody function.
Here is an illustrative example:
from pmwd import (
Configuration, Cosmology,
SimpleLCDM,
boltzmann,
white_noise, linear_modes,
lpt,
nbody,
scatter,
)
import jax
import jax.numpy as jnp
from jax import jit
def model(modes, cosmo, conf):
modes = linear_modes(modes, cosmo, conf, None, False)
ptcl, obsvbl = lpt(modes, cosmo, conf)
ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf)
dens = scatter(ptcl, conf)
return dens
#@jit #uncomment this, the error occurs!
def obj(modes, cosmo, conf):
dens = model(modes, cosmo, conf)
return jnp.sum(dens)
conf = Configuration(ptcl_spacing=4, ptcl_grid_shape=(2,)*3, lpt_order=1, float_dtype=jnp.float64, \
a_start=0.01, a_nbody_maxstep=1)
p0 = jnp.array([0.3, 0.8])
cosmo = Cosmology.from_sigma8(conf, Omega_m=0.3, sigma8=0.8, n_s=0.96, Omega_b=0.05, h=0.7)
cosmo = boltzmann(cosmo, conf)
obj_func = lambda z: obj(z, cosmo, conf)
vng = jax.value_and_grad(obj_func, argnums=(0))
data_ = jnp.array([[[1,1], [1,1]], [[1, 1], [1, 1]]], dtype=jnp.float64)
vng(data_)
And I got the following error for the script above when enabling @jax.jit:
(jax-test) [user@cluster ~]$ python jit_issue.py
Traceback (most recent call last):
File "/home/user/jit_issue.py", line 39, in <module>
vng(data_)
File "/home/user/jit_issue.py", line 34, in <lambda>
obj_func = lambda z: obj(z, cosmo, conf)
File "/home/user/jit_issue.py", line 24, in obj
dens = model(modes, cosmo, conf)
File "/home/user/jit_issue.py", line 18, in model
ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: invalid literal for int() with base 10: 'int16[8,3]'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/user/jit_issue.py", line 39, in <module>
vng(data_)
File "/home/user/jit_issue.py", line 34, in <lambda>
obj_func = lambda z: obj(z, cosmo, conf)
File "/home/user/work/pmwd/pmwd/tree_util.py", line 97, in tree_unflatten
return cls(**dict(zip(children_names, children)),
File "<string>", line 9, in __init__
File "/home/user/work/pmwd/pmwd/particles.py", line 72, in __post_init__
else jnp.asarray(value, dtype=dtype))
File "/home/user/.conda/envs/jax-test/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2206, in asarray
return array(a, dtype=dtype, copy=False, order=order) # type: ignore
File "/home/user/.conda/envs/jax-test/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2152, in array
out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False) # type: ignore[arg-type]
ValueError: invalid literal for int() with base 10: 'int16[8,3]'
I am using jax-0.4.23 [cpu] & pmwd-0.1.dev122+g1e1c634.d20240303 (GitHub commit 1e1c634). Is this expected behavior?
Mar 6th update: I attempted to run the script above in an environment with a GPU installation of JAX, and the error persists, so I don't think it is a CPU-only problem.
The odeint function in jax.experimental.ode determines the initial step size internally. However, the returned value can be nan when the initial derivative fun(t0, y0) is much smaller than the initial value y0. This is caused by the internal algorithm, which evaluates the function at a value ~y0/fun(t0, y0), which can be far beyond the valid range of the function and thus returns nan.
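To make the mechanism concrete, here is a sketch of the first-guess step size from the heuristic in Hairer, Norsett and Wanner, "Solving Ordinary Differential Equations I", Sec. II.4, which, as far as we understand JAX's implementation, is what odeint uses before its first extra function evaluation at t0 + h0. The helper name hairer_first_guess is ours, and the defaults mirror odeint's rtol = atol = 1.4e-8. For a scalar state, h0 is roughly 0.01 * |y0| / |f0|, so a small derivative pushes the trial point far from t0.

```python
import jax.numpy as jnp


def hairer_first_guess(y0, f0, rtol=1.4e-8, atol=1.4e-8):
    """First trial step h0 of the Hairer-Norsett-Wanner heuristic."""
    scale = atol + jnp.abs(y0) * rtol
    d0 = jnp.linalg.norm(y0 / scale)  # scaled size of the state
    d1 = jnp.linalg.norm(f0 / scale)  # scaled size of the derivative
    # Very small norms fall back to a fixed 1e-6; otherwise 0.01 * d0 / d1.
    return jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)


y0 = jnp.array([1.0])
h_big = hairer_first_guess(y0, jnp.array([1e-4]))    # about 100: far from t0
h_tiny = hairer_first_guess(y0, jnp.array([1e-15]))  # hits the 1e-6 fallback
```

With y0 = 1 and f0 = 1e-4, the function would be probed about 100 units away from t0 in the integration variable, which for an lna grid is well outside any physically valid range.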
With Epod, we now have MLP emulators for growth function that are quite fast.
Do we want to include them in official pmwd code, and if so, how? Ideally one would like to keep a way of having both, MLP and odeint the way it is currently implemented, with a way of switching back and forth between them.
I have a sample version that I am currently using, I will create a different branch to push it as an example and we can discuss that further.
I assume there is a reason for not enabling Python 3.7 in setup.py ^^ but Colab is currently on Python 3.7. I just tried to run the example code there and it crashes with this error:
/usr/local/lib/python3.7/dist-packages/pmwd/boltzmann.py in growth_integ(cosmo, conf)
199 ), axis=1)
200
--> 201 return cosmo.replace(growth=growth)
202
203
/usr/local/lib/python3.7/dist-packages/pmwd/tree_util.py in replace(self, **changes)
127 def replace(self, **changes):
128 """Create a new object of the same type, replacing fields with changes."""
--> 129 return dataclasses.replace(self, **changes)
130
131 cls.replace = replace
/usr/lib/python3.7/dataclasses.py in replace(*args, **changes)
1270 if f.name not in changes:
1271 if f._field_type is _FIELD_INITVAR:
-> 1272 raise ValueError(f"InitVar {f.name!r} "
1273 'must be specified with replace()')
1274 changes[f.name] = getattr(obj, f.name)
ValueError: InitVar 'Omega_k_fixed' must be specified with replace()
I assume there may be new functionality in Python 3.8 that makes this work, which is missing in 3.7? If so, I'd be happy to try to implement a workaround. Being able to run things on Colab is super important.
On the master branch, this line causes the following error:
File "/home/mattho/git/ltu-cmass/cmass/nbody/pmwd.py", line 74, in run_density
ic = linear_modes(wn, pmcosmo, pmconf)
ValueError: the `static_argnums` argument to `jax.checkpoint` / `jax.remat` can only take integer values greater than or equal to `-len(args)` and less than `len(args)`, but got (4,)
It seems @eelregit pointed this out as a JAX bug, but it hasn't been resolved yet.
Could a temporary workaround be pushed to master? This is currently breaking the master branch.
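The error message says that static_argnums=(4,) points past the arguments actually passed (only three positional arguments were given). One possible stopgap, untested against pmwd master, is to pass every argument up to the static index explicitly at the call site, as the linear_modes(modes, cosmo, conf, None, False) call in the jit example on this page does. The sketch shows the pattern on a hypothetical function scaled_power whose trailing argument is marked static for jax.checkpoint.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a checkpointed function with a static trailing arg.
@partial(jax.checkpoint, static_argnums=(2,))
def scaled_power(x, amp, power=2):
    return amp * x ** power


x = jnp.array(3.0)
# Pass the static argument explicitly so that index 2 is within len(args).
y = scaled_power(x, 2.0, 2)
g = jax.grad(scaled_power)(x, 2.0, 2)
```

This only sidesteps the argument-count check at the call site; it does not address the underlying JAX issue.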