gpSLDS

This repository contains an implementation of Gaussian Process Switching Linear Dynamical Systems (gpSLDS), described in the paper here (insert link).

Repo structure

gpslds/                         Source code for gpSLDS model implementation.
    em.py                           Implements variational EM and contains the main gpSLDS fitting function.
    initialization.py               Functions for initializing model parameters.
    kernels.py                      GP kernel functions, including our smoothly switching linear kernel.
    likelihoods.py                  Gaussian and Poisson observation models.
    quadrature.py                   Quadrature object for approximating kernel expectations.
    simulate_data.py                Helper functions for sampling from the model.
    transition.py                   Defines GP object for model fitting.
    utils.py                        Variety of helper functions.
data/                           Code and data files for main synthetic data example.
    fit_plds.py                     Script for fitting a Poisson LDS to initialize the Poisson process observation model parameters.
    generate_synthetic_data.py      Script for generating synthetic data.
    synthetic_data.pkl              Pickle file containing synthetic data.
    synthetic_plds_emissions.pkl    Pickle file containing initial observation model parameters for synthetic data.
synthetic_data_demo.ipynb       Demo notebook fitting gpSLDS to synthetic data.

Data format

To use the gpSLDS on your own data, you will need to ensure that you have:

  • A JAX array ys_binned of shape (n_trials, n_timesteps, n_output_dims). To model the data in effectively continuous time, n_timesteps should be the number of time bins at a discretization step that is small relative to the data sampling interval. Trials of varying length are assumed to be zero-padded to a common n_timesteps.
  • A JAX array t_mask of shape (n_trials, n_timesteps). This is 1 for observed timesteps and 0 for unobserved timesteps.
  • A JAX array trial_mask of shape (n_trials, n_timesteps). This is 1 for timesteps that fall within a trial and 0 for zero-padded timesteps.
  • (Optional) A JAX array inputs of shape (n_trials, n_timesteps, n_input_dims) consisting of external stimuli.
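The sketch below shows one way to build these three arrays from trials of different lengths. It is an illustration under stated assumptions, not code from this repository: `format_trials` is a hypothetical helper, each trial is assumed to be a NumPy array of shape (n_obs_i, n_output_dims), and observations are assumed to arrive every `stride` fine-grid bins (that is, the sampling interval is an integer multiple of the model's discretization step). Observations are placed at every `stride`-th bin, t_mask marks those bins, and trial_mask marks every bin inside the trial.

```python
import numpy as np


def format_trials(trials, stride):
    """Hypothetical helper: pad variable-length trials onto a fine time grid.

    Assumes each trial is an array of shape (n_obs_i, n_output_dims) whose
    samples are spaced `stride` fine-grid bins apart.
    Returns NumPy arrays ys_binned, t_mask, trial_mask.
    """
    n_trials = len(trials)
    n_output_dims = trials[0].shape[1]
    # Each trial spans n_obs_i * stride fine bins; pad all trials to the longest.
    n_timesteps = max(len(y) for y in trials) * stride

    ys_binned = np.zeros((n_trials, n_timesteps, n_output_dims))
    t_mask = np.zeros((n_trials, n_timesteps))
    trial_mask = np.zeros((n_trials, n_timesteps))

    for i, y in enumerate(trials):
        n_bins = len(y) * stride
        obs_bins = np.arange(len(y)) * stride  # fine-grid bins holding an observation
        ys_binned[i, obs_bins] = y
        t_mask[i, obs_bins] = 1.0       # observed timesteps
        trial_mask[i, :n_bins] = 1.0    # timesteps inside the trial (not padding)
    return ys_binned, t_mask, trial_mask


# Usage: two trials with 3 and 5 observations, sampled every 4 fine bins.
rng = np.random.default_rng(0)
trials = [rng.normal(size=(3, 2)), rng.normal(size=(5, 2))]
ys_binned, t_mask, trial_mask = format_trials(trials, stride=4)
print(ys_binned.shape)  # (2, 20, 2)
# Convert each array with jax.numpy.asarray(...) before passing it to gpSLDS.
```

If your data are already binned at the model's discretization step (for example, spike counts), use `stride=1`; in that case t_mask is identical to trial_mask.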

For a worked example, see synthetic_data_demo.ipynb, which demonstrates data formatting and model fitting on synthetic data.

How to run

We currently recommend running this code on an NVIDIA A100 GPU. The quickest way to get started is Google Colab with an A100 GPU runtime, as demonstrated in synthetic_data_demo.ipynb.
