
ssm-jax's Introduction

SSM: Bayesian learning and inference for state space models


Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend.

Update: This project has been superseded by DYNAMAX. Check it out!

Example

A quick demonstration of some of the most basic elements of SSM. Check out the example notebooks for more!

from ssm.hmm import GaussianHMM
import jax.random as jr

# create a true HMM model
hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(0))
states, data = hmm.sample(key=jr.PRNGKey(1), num_steps=500, num_samples=5)

# create a test HMM model
test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))

# fit it to our sampled data
log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")

Installation for Development

# use your favorite venv system
conda create -n ssm_jax python=3.9
conda activate ssm_jax

# in repo root directory...
pip install -r requirements.txt

Project Structure

.
├── docs                      # [documentation]
├── notebooks                 # [example jupyter notebooks]
├── ssm                       # [main code repository]
│   ├── hmm                       # hmm   models
│   ├── factorial_hmm             # factorial hmm models
│   ├── arhmm                     # arhmm models
│   ├── twarhmm                   # twarhmm models
│   ├── lds                       # lds   models
│   ├── slds                      # slds  models
│   ├── inference                 # inference code
│   └── distributions             # distributions (generally, extensions of tfp distributions)
└── tests                     # [tests]
    ├── [unit tests]              # unit test files mirroring the structure of ssm directory
    │   ...
    └── timing_comparisons        # benchmarking code (including comparisons to SSM_v0)

Documentation

Click here for documentation

ssm-jax's People

Contributors

ahwillia, jcostacurta11, matthew9671, schlagercollin, slinderman


ssm-jax's Issues

Hamiltonian Monte Carlo (HMC) for HMM example

SSM's hidden Markov model (HMM) objects expose a function to compute the marginal likelihood of the data, summing over the discrete latent states. This function can be automatically differentiated with jax.grad. Use TensorFlow Probability's Hamiltonian Monte Carlo (HMC) functionality to perform Bayesian inference over HMM parameters, using the marginal likelihood and a prior on parameter values.
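
A minimal, self-contained sketch of the proposed recipe using TFP's JAX substrate. The toy target below is a stand-in: in the real task, target_log_prob_fn would return the HMM's marginal log likelihood plus a log prior on the parameters.

import jax.numpy as jnp
import jax.random as jr
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Stand-in target: in the real task this would be
# log p(params) + log p(data | params), with the HMM marginal likelihood
# supplying the second term (it sums out the discrete states and is
# differentiable with jax.grad).
def target_log_prob_fn(theta):
    return tfd.Normal(0., 1.).log_prob(theta)

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.1,
    num_leapfrog_steps=8)

samples, is_accepted = tfp.mcmc.sample_chain(
    num_results=1000,
    num_burnin_steps=200,
    current_state=jnp.zeros(()),
    kernel=kernel,
    trace_fn=lambda _, pkr: pkr.is_accepted,
    seed=jr.PRNGKey(0))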

Does dynamax support rSLDS? And where can I find the variance explained by each latent?

I noticed that this project has been superseded by dynamax. However, I cannot find any place where dynamax incorporates SLDS or rSLDS. Will they be included in a future release, or am I overlooking something?

I am also trying to fit my neural data with the rSLDS framework. As mentioned in issue #163, is it possible to get the variance explained by each latent, so that a reasonable number of states and dimensions can be specified? Thank you so much in advance!

Error: `model` must be convertible to `dict` (saw: DeviceArray)

Dear all,

I am trying to run the GaussianHMM example code, but I get an error: "TypeError: `model` must be convertible to `dict` (saw: DeviceArray)."

I searched for this error in the JAX community but could not find a solution. Could you please help me out? Thank you very much!

from ssm.hmm import GaussianHMM
import jax.random as jr

# create a true HMM model
hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(0))
states, data = hmm.sample(key=jr.PRNGKey(1), num_steps=500, num_samples=5)

# create a test HMM model
test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))

# fit it to our sampled data
log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")
Initializing...

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 12
      9 test_hmm = GaussianHMM(num_states=5, num_emission_dims=10, seed=jr.PRNGKey(32))
     11 # fit it to our sampled data
---> 12 log_probs, fitted_model, posteriors = test_hmm.fit(data, method="em")

File ~/tmp/ssm-jax/ssm/utils.py:254, in ensure_has_batch_dim.<locals>.ensure_has_batch_dim_decorator.<locals>.wrapper(*args, **kwargs)
    250         if key in bound_args.arguments and bound_args.arguments[key] is not None:
    251             bound_args.arguments[key] = \
    252                 tree_map(lambda x: x[None, ...], bound_args.arguments[key])
--> 254 return f(**bound_args.arguments)

File ~/tmp/ssm-jax/ssm/hmm/base.py:201, in HMM.fit(self, data, covariates, metadata, method, num_iters, tol, initialization_method, key, verbosity)
    199 if initialization_method is not None:
    200     if verbosity >= Verbosity.LOUD : print("Initializing...")
--> 201     self.initialize(key, data, method=initialization_method)
    202     if verbosity >= Verbosity.LOUD: print("Done.", flush=True)
    204 if method == "em":

File ~/tmp/ssm-jax/ssm/utils.py:254, in ensure_has_batch_dim.<locals>.ensure_has_batch_dim_decorator.<locals>.wrapper(*args, **kwargs)
    250         if key in bound_args.arguments and bound_args.arguments[key] is not None:
    251             bound_args.arguments[key] = \
    252                 tree_map(lambda x: x[None, ...], bound_args.arguments[key])
--> 254 return f(**bound_args.arguments)

File ~/tmp/ssm-jax/ssm/hmm/base.py:132, in HMM.initialize(self, key, data, covariates, metadata, method)
    129 dummy_posteriors = DummyPosterior(one_hot(assignments, self._num_states))
    131 # Do one m-step with the dummy posteriors
--> 132 self._emissions.m_step(data, dummy_posteriors)

File ~/tmp/ssm-jax/ssm/hmm/emissions.py:161, in ExponentialFamilyEmissions.m_step(self, dataset, posteriors, covariates, metadata)
    145 def m_step(self, dataset, posteriors, covariates=None, metadata=None) -> ExponentialFamilyEmissions:
    146     """Update the emissions distribution using an M-step.
    147 
    148     Operates over a batch of data (posterior must have the same batch dim).
   (...)
    159         emissions (ExponentialFamilyEmissions): updated emissions object
    160     """
--> 161     conditional = self._emissions_distribution_class.compute_conditional(
    162         dataset, weights=posteriors.expected_states, prior=self._prior)
    163     self._distribution = self._emissions_distribution_class.from_params(
    164         conditional.mode())
    165     return self

File ~/tmp/ssm-jax/ssm/distributions/expfam.py:98, in ExponentialFamilyDistribution.compute_conditional(cls, data, weights, prior)
     95     stats = tree_map(np.add, stats, prior.natural_parameters)
     97 # Compute the conditional distribution given the stats
---> 98 return cls.compute_conditional_from_stats(stats)

File ~/tmp/ssm-jax/ssm/distributions/expfam.py:75, in ExponentialFamilyDistribution.compute_conditional_from_stats(cls, stats)
     73 @classmethod
     74 def compute_conditional_from_stats(cls, stats):
---> 75     return get_prior(cls).from_natural_parameters(stats)

File ~/tmp/ssm-jax/ssm/distributions/niw.py:69, in NormalInverseWishart.from_natural_parameters(cls, natural_params)
     67 loc = np.einsum("...i,...->...i", s2, 1 / mean_precision)
     68 scale = s3 - np.einsum("...,...i,...j->...ij", mean_precision, loc, loc)
---> 69 return cls(loc, mean_precision, df, scale)

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:474, in JointDistributionNamed.__new__(cls, *args, **kwargs)
    470   model = kwargs.get('model')
    472 if not all(isinstance(d, tf.__internal__.CompositeTensor) or callable(d)
    473            for d in tf.nest.flatten(model)):
--> 474   return _JointDistributionNamed(*args, **kwargs)
    475 return super(JointDistributionNamed, cls).__new__(cls)

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:323, in _JointDistributionNamed.__init__(self, model, batch_ndims, use_vectorized_map, validate_args, experimental_use_kahan_sum, name)
    287 def __init__(self,
    288              model,
    289              batch_ndims=None,
   (...)
    292              experimental_use_kahan_sum=False,
    293              name=None):
    294   """Construct the `JointDistributionNamed` distribution.
    295 
    296   Args:
   (...)
    321       Default value: `None` (i.e., `"JointDistributionNamed"`).
    322   """
--> 323   super(_JointDistributionNamed, self).__init__(
    324       model,
    325       batch_ndims=batch_ndims,
    326       use_vectorized_map=use_vectorized_map,
    327       validate_args=validate_args,
    328       experimental_use_kahan_sum=experimental_use_kahan_sum,
    329       name=name or 'JointDistributionNamed')

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_sequential.py:362, in _JointDistributionSequential.__init__(self, model, batch_ndims, use_vectorized_map, validate_args, experimental_use_kahan_sum, name)
    360 self._model_trackable = model
    361 self._model = self._no_dependency(model)
--> 362 self._build(model)
    364 super(_JointDistributionSequential, self).__init__(
    365     dtype=None,  # Ignored; we'll override.
    366     batch_ndims=batch_ndims,
   (...)
    370     experimental_use_kahan_sum=experimental_use_kahan_sum,
    371     name=name)
    373 # If the model consists entirely of prebuilt distributions with no
    374 # dependencies, cache them directly to avoid a sample call down the road.

File ~/anaconda3/envs/ssm_jax/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/joint_distribution_named.py:334, in _JointDistributionNamed._build(self, model)
    332 """Creates `dist_fn`, `dist_fn_wrapped`, `dist_fn_args`, `dist_fn_name`."""
    333 if not _is_dict_like(model):
--> 334   raise TypeError('`model` must be convertible to `dict` (saw: {}).'.format(
    335       type(model).__name__))
    336 [
    337     self._dist_fn,
    338     self._dist_fn_wrapped,
    339     self._dist_fn_args,
    340     self._dist_fn_name,  # JointDistributionSequential doesn't have this.
    341 ] = _prob_chain_rule_model_flatten(model)

TypeError: `model` must be convertible to `dict` (saw: DeviceArray).

jax tree_multimap deprecated

Hi, I'm excited to start using this codebase.

I followed the "installation for development" instructions in the README, but then from ssm.hmm import GaussianHMM gives me an error about tree_multimap. It looks like jax.tree_util.tree_multimap was removed in jax v0.3.16 (changelog) and the latest release is 0.3.20. It looks like tree_multimap was simply replaced by jax.tree_util.tree_map and that the functionality is the same (jax issue, PR).

In this repo, tree_multimap is only used in ssm.utils.tree_all_equal, so it should be straightforward to remove.
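
For reference, an illustrative drop-in replacement (the actual tree_all_equal body in ssm/utils.py may differ); jax.tree_util.tree_map accepts multiple trees, so the change is mechanical:

import jax
import jax.numpy as jnp

def tree_all_equal(tree_a, tree_b):
    # Hypothetical reimplementation; previously used jax.tree_util.tree_multimap,
    # which behaves identically to tree_map on multiple trees.
    eq = jax.tree_util.tree_map(lambda a, b: jnp.all(a == b), tree_a, tree_b)
    return all(jax.tree_util.tree_leaves(eq))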

Best,

Jack

Numerical instability PoissonHmm

Hi, I've been trying to fit a PoissonHmm to some simulations made in NEST (https://github.com/nest/nest-simulator). Given the length of the simulations and the high spike counts, I thought this refactor would improve the running time relative to the original SSM, where everything works but can take a while (not long enough to get worried). However, with ssm-jax I encounter numerical instability in the EM update step: the following assertion is raised no matter the number of iterations, even on a small sample of the data:

assert np.isfinite(lp), "NaNs in marginal log probability"

I was wondering if there is a known limitation with a large number of spike counts or something that I am missing.

Gibbs sampling for Gaussian LDS

LDSs with Gaussian emissions admit a simple Gibbs sampling algorithm: alternate between the following two steps (a rough sketch follows the list):

  1. Sample the continuous latent states given the parameters and data using LDSPosterior.sample()
  2. Sample the parameters from their conditional distribution given the latent states and data. This will follow the same recipe as the GaussianEmissions.m_step(), but it will use conditional.sample() instead of conditional.mode().
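
A skeletal sketch of that alternation. The helpers e_step and sample_params_conditional are hypothetical stand-ins for the repo's posterior computation and for the conditional.sample() analogue of GaussianEmissions.m_step():

import jax.random as jr

def lds_gibbs(key, lds, data, num_iters=100):
    draws = []
    for _ in range(num_iters):
        key, key_x, key_theta = jr.split(key, 3)
        # Step 1: sample the continuous latents given the current parameters.
        posterior = lds.e_step(data)              # hypothetical helper
        latents = posterior.sample(seed=key_x)    # LDSPosterior.sample()
        # Step 2: sample parameters from their conditional given latents and
        # data, i.e. the m_step recipe with conditional.sample(), not .mode().
        lds = sample_params_conditional(key_theta, lds, latents, data)  # hypothetical
        draws.append(latents)
    return draws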

documentation: how can we extract the emission parameters from a model?

In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/notebooks/bernoulli-hmm-example.ipynb, you create an HMM with a random Bernoulli observation model. How can I extract and plot the underlying nstates x ndims matrix of probabilities? (It's buried behind some TFP class.)

Similarly, in https://github.com/lindermanlab/ssm-jax-refactor/blob/main/notebooks/gaussian-hmm-example.ipynb, how do I extract the observation parameters of the learned model?
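
One hedged suggestion, not an official answer: the attribute path _emissions._distribution appears in the traceback earlier on this page (ssm/hmm/emissions.py), and the wrapped TFP distribution exposes its parameters through standard accessors. The exact path may differ across model classes and versions:

# Assumed attribute path; verify against your version of the library.
dist = fitted_hmm._emissions._distribution   # the wrapped TFP distribution

# Bernoulli HMM: (num_states, num_dims) matrix of probabilities
probs = dist.probs_parameter()

# Gaussian HMM analogue: means and covariances of the emission distributions
means = dist.mean()
covs = dist.covariance()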

Fit HMMs with fixed emissions matrix

Is there a way to do this? Essentially I'd like to pass a flag to fit(..., method="em") that would tell the code to optimize the transitions matrix and keep the emissions fixed (or vice versa!)

ABC imports in Colab

Hey @schlagercollin, I'm running into an issue with the new ABC imports when running in Colab. Is this a Python versioning issue?

TypeError                                 Traceback (most recent call last)
<ipython-input> in <module>()
----> 6 from ssm.lds import GaussianLDS
...
/usr/local/lib/python3.7/dist-packages/ssm/utils.py in <module>()
    101     z2: Sequence[int],
    102     K1: Optional[int] = None,
--> 103     K2: Optional[int] = None,
    104 ):
    105 """
TypeError: 'ABCMeta' object is not subscriptable

Gibbs sampling for HMM

HMMs with exponential family emissions admit a simple Gibbs sampling algorithm: alternate between the following two steps (a partial sketch of step 2 follows the list):

  1. Sample the discrete latent states given the parameters and data using HMMPosterior.sample()
  2. Sample the parameters from their conditional distribution given the latent states and data. This will follow the same recipe as the ExponentialFamilyEmissions.m_step(), but it will use conditional.sample() instead of conditional.mode().
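
A partial sketch of step 2, reusing the compute_conditional machinery that m_step calls; the private names below are taken from the traceback in the DeviceArray issue earlier on this page, and the surrounding loop plus the HMMPosterior.sample() call are omitted:

import jax.random as jr
from jax.nn import one_hot

def sample_emission_params(key, hmm, states, data):
    # Mirrors ExponentialFamilyEmissions.m_step, but draws from the
    # conditional instead of taking its mode. Attribute names are assumed
    # from the traceback above and may differ across versions.
    emissions = hmm._emissions
    conditional = emissions._emissions_distribution_class.compute_conditional(
        data, weights=one_hot(states, hmm._num_states), prior=emissions._prior)
    return conditional.sample(seed=key)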

Hamiltonian Monte Carlo for Gaussian LDS example

SSM's Gaussian linear dynamical system (LDS) objects expose a function to compute the marginal likelihood of the data, integrating over the continuous latent states. This function can be automatically differentiated with jax.grad. Use TensorFlow Probability's Hamiltonian Monte Carlo (HMC) functionality to perform Bayesian inference over LDS parameters, using the marginal likelihood and a prior on parameter values.
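
Before wiring up HMC, the differentiability claim can be checked directly. A sketch under assumed names: the constructor arguments mirror the style of the GaussianHMM example above, and the marginal_likelihood method name and pytree registration of models are assumptions, not verified here:

import jax
import jax.random as jr
from ssm.lds import GaussianLDS

# Assumed API, mirroring the GaussianHMM example in the README above.
lds = GaussianLDS(num_latent_dims=2, num_emission_dims=5, seed=jr.PRNGKey(0))
states, data = lds.sample(key=jr.PRNGKey(1), num_steps=200)

# If models are registered as JAX pytrees (assumed), the marginal log
# likelihood can be differentiated with respect to the parameters; the same
# target plus a log prior is what tfp.mcmc.HamiltonianMonteCarlo consumes,
# as in the HMM issue above. (.sum() in case the output is per-sample.)
grad_fn = jax.grad(lambda model: model.marginal_likelihood(data).sum())
model_grads = grad_fn(lds)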
