Comments (2)
@dfm mentioned that using scan to replace fori_loop should help reduce the compilation time.
from flowmc.
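For reference, a minimal sketch of what swapping fori_loop for scan could look like. The `step` function is a hypothetical stand-in for one sampler update, not flowMC code, and the sketch only shows that the two loops compute the same result; whether the swap actually shortens compilation for a given model would need to be measured.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical stand-in for one sampler update (not flowMC code).
    return 0.5 * x + 1.0


def run_fori(x0, n_steps):
    # fori_loop only returns the final carry.
    return jax.lax.fori_loop(0, n_steps, lambda i, x: step(x), x0)


def run_scan(x0, n_steps):
    def body(x, _):
        x = step(x)
        return x, x  # (new carry, per-step output)

    # xs=None with an explicit length runs the body n_steps times.
    final, history = jax.lax.scan(body, x0, xs=None, length=n_steps)
    return final, history


x0 = jnp.zeros(3)
# n_steps is marked static so it is a plain Python int while tracing.
final_fori = jax.jit(run_fori, static_argnums=1)(x0, 10)
final_scan, history = jax.jit(run_scan, static_argnums=1)(x0, 10)
```

A side benefit of scan is that it also returns the stacked per-step outputs (`history` here), which fori_loop does not.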
I did some more testing on the compilation time of MALA.py. I think the main overhead comes from JAX compiling the derivative of the log likelihood multiple times.
As an example, here is the MALA kernel for one proposal.
...
key1, key2 = jax.random.split(rng_key)
proposal = position + dt * d_logpdf(position)
proposal += jnp.sqrt(2 * dt) * jax.random.normal(key1, shape=position.shape)
ratio = logpdf(proposal) - logpdf(position)
ratio -= ((position - proposal - dt * d_logpdf(proposal)) ** 2 / (4 * dt)).sum()
ratio += ((proposal - position - dt * d_logpdf(position)) ** 2 / (4 * dt)).sum()
proposal_log_prob = logpdf(proposal)
log_uniform = jnp.log(jax.random.uniform(key2))
do_accept = log_uniform < ratio
position = jnp.where(do_accept, proposal, position)
log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
return position, log_prob, do_accept
I tried this example with a gravitational wave likelihood, where one compilation of d_logpdf takes around 70s.
If I jit the entire kernel, the compilation time is around 300 seconds.
This seems to indicate that JAX does not reuse the cached compilation of d_logpdf when jitting the kernel. Even if 'proposal' and 'position' somehow triggered a recompilation, the compilation time should still be smaller than that.
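A comparison like the one above can be set up with JAX's ahead-of-time API, which separates compilation from execution. This is only a sketch: the toy `logpdf`, the `kernel` function, and the `compile_seconds` helper are hypothetical stand-ins, not the gravitational wave likelihood or the flowMC kernel, so the numbers it prints will be tiny.

```python
import time

import jax
import jax.numpy as jnp


def logpdf(x):
    # Toy standard-normal target; stand-in for the real likelihood.
    return -0.5 * jnp.sum(x ** 2)


d_logpdf = jax.grad(logpdf)


def compile_seconds(fn, *args):
    """Time tracing, lowering and XLA compilation, without running fn."""
    start = time.perf_counter()
    jax.jit(fn).lower(*args).compile()
    return time.perf_counter() - start


def kernel(x):
    # Evaluates the gradient at two different points, as the MALA kernel does.
    proposal = x + 0.1 * d_logpdf(x)
    return proposal + 0.1 * d_logpdf(proposal)


x = jnp.ones(4)
t_grad = compile_seconds(d_logpdf, x)
t_kernel = compile_seconds(kernel, x)
print(f"d_logpdf: {t_grad:.3f}s, kernel: {t_kernel:.3f}s")
```

With an expensive likelihood in place of the toy one, the ratio between the two timings gives a rough idea of how many times the gradient is effectively being compiled inside the kernel.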
Another possibility is that JAX unrolls the entire mala_kernel computation graph without taking into account that logpdf and d_logpdf are used multiple times and hence share the same graph.
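If that is what happens, one thing to try is reducing how many times logpdf and d_logpdf appear in the traced kernel. The sketch below is an assumption on my part, not the flowMC implementation: it uses `jax.value_and_grad` so that the value and gradient come from a single call at each point, which leaves two calls in the kernel instead of six. The toy `logpdf` and the explicit `dt` argument are placeholders, and I have not checked how much this helps with an expensive likelihood.

```python
import jax
import jax.numpy as jnp


def logpdf(x):
    # Toy standard-normal target; stand-in for the real likelihood.
    return -0.5 * jnp.sum(x ** 2)


# One traced call returns both the log-density and its gradient.
logpdf_and_grad = jax.value_and_grad(logpdf)


def mala_kernel(rng_key, position, dt):
    key1, key2 = jax.random.split(rng_key)

    # Single evaluation at the current position.
    log_prob, grad = logpdf_and_grad(position)
    noise = jax.random.normal(key1, shape=position.shape)
    proposal = position + dt * grad + jnp.sqrt(2 * dt) * noise

    # Single evaluation at the proposed position.
    proposal_log_prob, proposal_grad = logpdf_and_grad(proposal)

    # Same acceptance ratio as the kernel above, reusing the stored values.
    ratio = proposal_log_prob - log_prob
    ratio -= ((position - proposal - dt * proposal_grad) ** 2 / (4 * dt)).sum()
    ratio += ((proposal - position - dt * grad) ** 2 / (4 * dt)).sum()

    do_accept = jnp.log(jax.random.uniform(key2)) < ratio
    position = jnp.where(do_accept, proposal, position)
    log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    return position, log_prob, do_accept


position, log_prob, do_accept = jax.jit(mala_kernel)(
    jax.random.PRNGKey(0), jnp.ones(2), 0.1
)
```

Carrying the log-density and gradient of the current position between steps would cut this further to one evaluation per step, at the cost of a larger sampler state.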