Comments (9)
This is a great request. cc'ing @jvdillon, who leads development of tfp.mcmc.
Agreed, this would be nice to have. Let me ask around and I'll report back.
Thanks! I look forward to that! It would be great if we could have it soon!
Acknowledged.
In the meantime, check out help(tfp.mcmc.RandomWalkMetropolis). It has a fully functional example, and you can find more examples in the unit tests.
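For anyone who wants to see what that kernel does before digging into the docstring, here is a minimal plain-NumPy sketch of random-walk Metropolis with a uniform proposal, roughly the combination of `RandomWalkMetropolis` and `random_walk_uniform_fn(scale=2.)` used in the snippet further down this thread. It illustrates the algorithm only and is not TFP's implementation or API: the function and argument names are hypothetical, and seeding uses NumPy's `default_rng` rather than TFP's seed handling.

```python
import numpy as np


def random_walk_metropolis(log_prob, init, num_results, num_burnin, scale, seed=None):
    """Random-walk Metropolis with a uniform proposal on [x - scale, x + scale].

    Hypothetical plain-NumPy sketch of the algorithm, not TFP's implementation.
    Returns the post-burn-in samples and the post-burn-in acceptance rate.
    """
    rng = np.random.default_rng(seed)
    x = float(init)
    lp = log_prob(x)
    samples = np.empty(num_results)
    accepted = 0
    for step in range(num_burnin + num_results):
        # Symmetric proposal, so the Hastings correction term is zero.
        proposal = x + rng.uniform(-scale, scale)
        lp_prop = log_prob(proposal)
        # Accept with probability min(1, p(proposal) / p(x)), compared in log space.
        if np.log(rng.uniform()) < lp_prop - lp:
            x, lp = proposal, lp_prop
            if step >= num_burnin:
                accepted += 1
        if step >= num_burnin:
            samples[step - num_burnin] = x
    return samples, accepted / num_results


def std_normal_log_prob(x):
    # Unnormalized log density of N(0, 1); the normalizing constant cancels.
    return -0.5 * x * x


samples, acc_rate = random_walk_metropolis(
    std_normal_log_prob, init=1.0, num_results=20000, num_burnin=500,
    scale=2.0, seed=42)
print(samples.mean(), samples.std(), acc_rate)
```

Because the uniform proposal is symmetric, only the target log-densities enter the acceptance test; an asymmetric proposal would need an extra correction term.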
It would also be nice to have a built-in function for calculating the autocorrelation time.
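In the meantime, the integrated autocorrelation time tau = 1 + 2 * sum_k rho_k can be estimated by hand. The sketch below is a minimal NumPy version that assumes a single 1-D chain and uses a simple rule that truncates the sum at the first non-positive autocorrelation; Geyer's initial-sequence estimators or Sokal's adaptive window are more robust choices. The function names are hypothetical. For comparison, current TFP releases ship `tfp.mcmc.effective_sample_size` and `tfp.stats.auto_correlation`.

```python
import numpy as np


def autocorrelation(x):
    """Normalized autocorrelation rho_k of a 1-D chain, computed with an FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    centered = x - x.mean()
    # Zero-pad to length 2n so the circular correlation equals the linear one.
    f = np.fft.rfft(centered, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f))[:n] / n
    return acov / acov[0]


def integrated_autocorr_time(x):
    """tau = 1 + 2 * sum_k rho_k, stopping at the first non-positive lag."""
    rho = autocorrelation(x)
    tau = 1.0
    for k in range(1, len(rho)):
        if rho[k] <= 0:
            break
        tau += 2.0 * rho[k]
    return tau


# An AR(1) chain x_t = phi * x_{t-1} + noise has tau = (1 + phi) / (1 - phi).
rng = np.random.default_rng(0)
phi, n = 0.5, 100_000
chain = np.empty(n)
chain[0] = rng.normal()
noise = rng.normal(size=n)
for t in range(1, n):
    chain[t] = phi * chain[t - 1] + noise[t]

tau = integrated_autocorr_time(chain)
print(tau, n / tau)  # n / tau is a rough effective sample size
```

For phi = 0.5 the exact value is tau = 3, so the estimate should land near 3 and the effective sample size near n / 3.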
I tried the following code from the unit tests. My environment is Python 3.6 with the latest nightly versions of both tf and tfp.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions  # TFP's distributions module
dtype = np.float32
# Standard normal target: the chain should recover mean ~0 and std ~1.
target = tfd.Normal(loc=dtype(0.0), scale=dtype(1.0))
samples, _ = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=dtype(1.0),
    kernel=tfp.mcmc.RandomWalkMetropolis(
        target.log_prob,
        new_state_fn=tfp.mcmc.random_walk_uniform_fn(scale=2.),
        seed=42),
    num_burnin_steps=500,
    parallel_iterations=1)  # For determinism.
# Sample mean and (biased) sample standard deviation over the chain.
sample_mean = tf.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.reduce_mean(tf.squared_difference(samples, sample_mean), axis=0))
# Graph-mode TF1: evaluate the tensors inside a session.
with tf.Session() as sess:
    [sample_mean_, sample_std_] = sess.run([sample_mean, sample_std])
However, what I get is:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-29-faad2d5672e2> in <module>()
8 new_state_fn=tfp.mcmc.random_walk_uniform_fn(scale=2.),seed=42),
9 num_burnin_steps=500,
---> 10 parallel_iterations=1
11 ) # For determinism.
12
~/anaconda3/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/sample.py in sample_chain(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, parallel_iterations, name)
247 dtype=tf.int32), # num_steps
248 initializer=[current_state, previous_kernel_results],
--> 249 parallel_iterations=parallel_iterations)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/functional_ops.py in scan(fn, elems, initializer, parallel_iterations, back_prop, swap_memory, infer_shape, name)
618 parallel_iterations=parallel_iterations,
619 back_prop=back_prop, swap_memory=swap_memory,
--> 620 maximum_iterations=n)
621
622 results_flat = [r.stack() for r in r_a]
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations)
3206 if loop_context.outer_context is None:
3207 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
-> 3208 result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
3209 if maximum_iterations is not None:
3210 return result[1]
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in BuildLoop(self, pred, body, loop_vars, shape_invariants)
2944 with ops.get_default_graph()._lock: # pylint: disable=protected-access
2945 original_body_result, exit_vars = self._BuildLoop(
-> 2946 pred, body, original_loop_vars, loop_vars, shape_invariants)
2947 finally:
2948 self.Exit()
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants)
2881 flat_sequence=vars_for_body_with_tensor_arrays)
2882 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
-> 2883 body_result = body(*packed_vars_for_body)
2884 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
2885 if not nest.is_sequence(body_result):
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in <lambda>(i, lv)
3182 cond = lambda i, lv: ( # pylint: disable=g-long-lambda
3183 math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
-> 3184 body = lambda i, lv: (i + 1, orig_body(*lv))
3185
3186 if context.executing_eagerly():
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/functional_ops.py in compute(i, a_flat, tas)
607 packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
608 packed_a = output_pack(a_flat)
--> 609 a_out = fn(packed_a, packed_elems)
610 nest.assert_same_structure(
611 elems if initializer is None else initializer, a_out)
~/anaconda3/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/sample.py in _scan_body(args_list, num_steps)
235 previous_kernel_results,
236 ],
--> 237 parallel_iterations=parallel_iterations)[1:] # Lop off `it_`.
238
239 if previous_kernel_results is None:
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations)
3206 if loop_context.outer_context is None:
3207 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
-> 3208 result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
3209 if maximum_iterations is not None:
3210 return result[1]
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in BuildLoop(self, pred, body, loop_vars, shape_invariants)
2944 with ops.get_default_graph()._lock: # pylint: disable=protected-access
2945 original_body_result, exit_vars = self._BuildLoop(
-> 2946 pred, body, original_loop_vars, loop_vars, shape_invariants)
2947 finally:
2948 self.Exit()
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants)
2903 # during this comparison, because inputs are typically lists and
2904 # outputs of the body are typically tuples.
-> 2905 nest.assert_same_structure(list(packed_vars_for_body), list(body_result))
2906
2907 # Store body_result to keep track of TensorArrays returned by body
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/nest.py in assert_same_structure(nest1, nest2, check_types)
181 their substructures. Only possible if `check_types` is `True`.
182 """
--> 183 _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
184
185
TypeError: The two structures don't have the same nested structure.
First structure: type=list str=[<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity:0' shape=() dtype=int32>, <tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_1:0' shape=() dtype=float32>, MetropolisHastingsKernelResults(accepted_results=UncalibratedRandomWalkResults(log_acceptance_correction=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_2:0' shape=() dtype=float32>, target_log_prob=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_3:0' shape=() dtype=float32>), is_accepted=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_4:0' shape=() dtype=bool>, log_accept_ratio=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_5:0' shape=() dtype=float32>, proposed_state=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_6:0' shape=() dtype=float32>, proposed_results=UncalibratedRandomWalkResults(log_acceptance_correction=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_7:0' shape=() dtype=float32>, target_log_prob=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_8:0' shape=() dtype=float32>))]
Second structure: type=list str=[<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/add:0' shape=() dtype=int32>, <tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/choose_next_state/Select:0' shape=() dtype=float32>, MetropolisHastingsKernelResults(accepted_results=[<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/choose_inner_results/Select:0' shape=() dtype=float32>, <tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/choose_inner_results/Select_1:0' shape=() dtype=float32>], is_accepted=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/Less:0' shape=() dtype=bool>, log_accept_ratio=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/compute_log_accept_ratio/Sum:0' shape=() dtype=float32>, proposed_state=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/rwm_one_step/random_walk_uniform_fn/random_uniform:0' shape=() dtype=float32>, proposed_results=UncalibratedRandomWalkResults(log_acceptance_correction=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/rwm_one_step/zeros:0' shape=() dtype=float32>, target_log_prob=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/rwm_one_step/Normal/log_prob/sub:0' shape=() dtype=float32>))]
More specifically: The two namedtuples don't have the same sequence type. First structure type=UncalibratedRandomWalkResults str=UncalibratedRandomWalkResults(log_acceptance_correction=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_2:0' shape=() dtype=float32>, target_log_prob=<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/Identity_3:0' shape=() dtype=float32>) has type UncalibratedRandomWalkResults, while second structure type=list str=[<tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/choose_inner_results/Select:0' shape=() dtype=float32>, <tf.Tensor 'mcmc_sample_chain_11/scan/while/while/mh_one_step/choose_inner_results/Select_1:0' shape=() dtype=float32>] has type list
Any ideas why?
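Reading the last lines of the traceback: `accepted_results` entered the while loop as an `UncalibratedRandomWalkResults` namedtuple, but the loop body returned it as a plain `list` (the tensors come from `choose_inner_results/Select`), and `nest.assert_same_structure` treats a namedtuple and a list as different structures. The sketch below reproduces that kind of mismatch in plain Python. `same_structure` is a hypothetical, simplified stand-in for TensorFlow's `nest` check, not its actual implementation, and the namedtuple only mirrors the field names printed in the traceback.

```python
import collections

# Mirrors the field names printed in the traceback; not TFP's actual class.
UncalibratedRandomWalkResults = collections.namedtuple(
    "UncalibratedRandomWalkResults",
    ["log_acceptance_correction", "target_log_prob"])


def same_structure(a, b):
    """Simplified structure check: namedtuple types must match exactly,
    plain lists and tuples are treated as interchangeable, leaves always match."""
    a_seq = isinstance(a, (list, tuple))
    b_seq = isinstance(b, (list, tuple))
    if a_seq != b_seq:
        return False
    if not a_seq:
        return True  # two leaves
    a_named = hasattr(a, "_fields")
    b_named = hasattr(b, "_fields")
    if (a_named or b_named) and type(a) is not type(b):
        return False
    if len(a) != len(b):
        return False
    return all(same_structure(x, y) for x, y in zip(a, b))


loop_var = UncalibratedRandomWalkResults(0.0, -1.4)
# Selecting field by field and returning a plain list drops the namedtuple type.
body_out_bad = [field for field in loop_var]
# Rebuilding with the original type keeps the structure intact.
body_out_good = type(loop_var)(*body_out_bad)

print(same_structure(loop_var, body_out_bad))   # False
print(same_structure(loop_var, body_out_good))  # True
```

Rebuilding the result with `type(value)(*fields)` is one way to avoid this kind of mismatch; it is not necessarily how the kernel was actually fixed.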
I am getting the same error under py3.
Thanks for bringing this to our attention. Can you raise a separate GitHub issue?
Just in case anyone wonders, the issue raised above was fixed in #22.
As for using MCMC in general, does this example make things clearer?
https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Eight_Schools.ipynb
Related Issues (20)
- Dirichlet distribution sampling issue when jit_compile=True
- AttributeError: 'SymbolicTensor' object has no attribute 'log_prob' when exporting train signature with `IndependentNormal` layer
- Add Poisson quantile
- Computing log_prob for tfd.Sample() with a different number of samples
- TruncatedCauchy gives wrong results sometimes
- `_parameter_properties` is not implemented for `LinearGaussianStateSpaceModel`
- tensorflow 2.16.1 breaks tensorflow-probability with Keras `3.0` API
- `LinearGaussianStateSpaceModel` filtering initial state is incorrect
- Piecewise distribution
- Keras not accepting character `/` from build_factored_surrogate_posterior
- A bug in Linear_Mixed_Effects_Models.ipynb
- Conditional input with multiple flows
- mlx backend
- Can't jit PoissonLogNormalQuadratureCompound log_prob
- autobnn error
- Addition of "location" type parameter in the Gamma distribution
- Unexpected Symbolic tensor in Tensorflow Probability tensor_coercible object (mixture layer)
- TFP JAX: The transition kernel drastically decreases speed.
- jax.dtypes.prng_key gives `AttributeError: module 'jax.dtypes' has no attribute 'prng_key'`
- MultivariateNormalTriL Layer appears to be incompatible with tf.keras in tf 2.16.1 and tfp 0.24