Comments (10)
Thank you for your work, @rlouf!
from blackjax.
Thank you very much for the fix ❤️
Thank you for the bug report and sorry for the inconvenience! I managed to reproduce your example and will investigate. @junpenglao this happens for NUTS as well.
I just tried NUTS on the current main branch. I'll try to find time tonight to check whether NUTS was working on c6f75e9.
Thanks for the feedback. Yes, we need to add a test based on the Monte Carlo central limit theorem for this.
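One way to phrase the Monte Carlo CLT test mentioned above is to check that the empirical mean of the chain lies within a few standard errors of the known target mean. A minimal NumPy sketch of the idea; the function name and tolerance are illustrative and this is not blackjax's actual test suite:

```python
import numpy as np

def check_mean_via_clt(samples, true_mean, true_var, num_sigmas=4.0):
    """Flag a failure if the empirical mean deviates from the target mean
    by more than `num_sigmas` Monte Carlo standard errors.

    Assumes (nearly) independent draws; correlated MCMC samples would
    need an effective-sample-size correction.
    """
    n = samples.shape[0]
    mcse = np.sqrt(true_var / n)  # standard error of the mean under the CLT
    return abs(np.mean(samples) - true_mean) <= num_sigmas * mcse

# i.i.d. draws from the standard normal target should pass the check.
rng = np.random.default_rng(0)
draws = rng.normal(loc=0.0, scale=1.0, size=50_000)
print(check_mean_via_clt(draws, true_mean=0.0, true_var=1.0))
```

A biased sampler is then caught automatically: shifting the same draws by 0.1 puts the mean dozens of standard errors away from zero, so the check fails.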
I am pretty sure, though not 100%, that the example code is correct. I also didn't spend much time trying to optimize the parameters.
NUTS seems to work on v1 (8209172), in the sense that the sample mean is close to zero and the sample variance is reasonably close to unity (1.17, though this might be due to poorly chosen NUTS parameters). On commit cfeffb1 I see different results compared to v1: the sample variance is now 0.33691323.
```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.stats  # explicit import needed in some jax versions
import blackjax.nuts as nuts
import matplotlib.pyplot as plt

# Standard normal target.
potential = lambda x: -jax.scipy.stats.norm.logpdf(x, loc=0.0, scale=1.0).squeeze()

initial_position = np.array([1.0])
initial_state = nuts.new_state(initial_position, potential)

step_size = 0.1
params = nuts.NUTSParameters(
    step_size=step_size,
    inv_mass_matrix=1 * jnp.ones_like(initial_position),
)

nuts_kernel = nuts.kernel(potential, params)
nuts_kernel = jax.jit(nuts_kernel)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states

rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, nuts_kernel, initial_state, 50_000)
samples = states.position.block_until_ready()

print(np.mean(samples, axis=0))
print(np.var(samples, axis=0))

plt.plot(samples)
plt.show()
```
| commit | 8209172 | fe6ff69 | cfeffb1 |
|---|---|---|---|
| output | [0.0027857] [1.1750876] | [0.0027857] [1.1750876] | [0.0067384] [0.33691323] |
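As a rough sanity check on the numbers above, sampling noise alone cannot explain a drop in the variance estimate from roughly 1 to 0.34. For n i.i.d. standard-normal draws, Var(s²) ≈ 2/(n − 1), so a quick sketch of the expected fluctuation (this assumes independent draws; autocorrelated MCMC samples would widen the bound somewhat, but not by this much):

```python
import numpy as np

n = 50_000
# Approximate standard error of the sample variance for i.i.d. N(0, 1) draws.
se_var = np.sqrt(2.0 / (n - 1))
print(f"std error of the variance estimate: {se_var:.4f}")   # → 0.0063
print(f"0.337 deviates from 1 by {(1.0 - 0.337) / se_var:.0f} std errors")
```

A deviation of over a hundred standard errors points to a genuine bug rather than an unlucky seed or badly tuned step size.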
I'll track this one down; it might be the reason why adaptation fails for the variance of a Gaussian target in #44. It must be the change in `generate_proposal`, since the bug affects both HMC and NUTS.
I will add the extra tests you were talking about, @junpenglao, with the bug fix.
I have found a bug which may explain what you observe (I will try to run the corrected code later tonight).
In the `init` function for the proposal, `position` is passed to the kinetic energy instead of `momentum`. The fact that we can pass both is a remnant of the code on which BlackJAX is based, where I tried to implement the SoftAbs metric. I will also remove that possibility for now.
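For context, this class of bug looks like the following. The initial energy of a proposal must use the kinetic energy of the momentum, not of the position. A schematic sketch with a simplified Gaussian kinetic energy; the function names are illustrative and this is not blackjax's actual code:

```python
import jax.numpy as jnp

def kinetic_energy(momentum, inv_mass_matrix):
    # Gaussian kinetic energy: 0.5 * p^T M^{-1} p (diagonal mass matrix).
    return 0.5 * jnp.dot(momentum, inv_mass_matrix * momentum)

def proposal_energy(position, momentum, potential, inv_mass_matrix):
    # Correct: the Hamiltonian is H(q, p) = U(q) + K(p).
    return potential(position) + kinetic_energy(momentum, inv_mass_matrix)
    # The bug amounted to the equivalent of:
    #   potential(position) + kinetic_energy(position, inv_mass_matrix)
    # which silently distorts the accept/reject step for HMC and NUTS.

# With U(q) = 0.5 * q^2: H = 0.5 * 1.0**2 + 0.5 * 2.0**2 = 2.5
potential = lambda q: 0.5 * jnp.sum(q ** 2)
q = jnp.array([1.0])
p = jnp.array([2.0])
print(proposal_energy(q, p, potential, jnp.ones_like(q)))  # → 2.5
```

Passing the position in place of the momentum still type-checks when both have the same shape, which is why the mistake is easy to make and worth guarding with a distributional test.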
The first commit on #47 fixes it. I will do a little refactor, add an extra test like @junpenglao suggested, and release a patch.
I corrected the bug and released a patch, 0.2.1, now available on PyPI. Thank you again for reporting!