gpjax's Introduction

GPJax aims to provide a low-level interface to Gaussian process (GP) models in JAX, structured to give researchers maximum flexibility in extending the code to suit their own needs. We define a GP prior in GPJax by specifying a mean and kernel function, then multiply it by a likelihood function to construct the posterior. The idea is that the code should be as close as possible to the maths we write on paper when working with GP models.

Package support

GPJax was created by Thomas Pinder. Today, the maintenance of GPJax is undertaken by Thomas and Daniel Dodd.

We would be delighted to review pull requests (PRs) from new contributors. Before contributing, please read our guide for contributing. If you do not have the capacity to open a PR, or you would like guidance on how best to structure a PR, then please open an issue. For broader discussions on best practices for fitting GPs, technical questions surrounding the mathematics of GPs, or anything else that you feel doesn't quite constitute an issue, please start a discussion thread in our discussion tracker.

We have recently set up a Slack channel where we hope to facilitate discussions around the development of GPJax and broader support for Gaussian process modelling. If you'd like to join the channel, then please follow this invitation link which will take you to the GPJax Slack community.

Simple example

This simple regression example illustrates how closely GPJax's API resembles the way we write the mathematics of Gaussian processes.

After importing the necessary dependencies, we'll simulate some data.

import gpjax as gpx
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.example_libraries import optimizers
from jax import jit

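key = jr.PRNGKey(123)
# Split the key so that the inputs and the noise are drawn independently.
x_key, noise_key = jr.split(key)

x = jr.uniform(key=x_key, minval=-3.0, maxval=3.0, shape=(50,)).sort().reshape(-1, 1)
y = jnp.sin(x) + jr.normal(noise_key, shape=x.shape) * 0.05
training = gpx.Dataset(X=x, y=y)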

The function of interest here is sinusoidal, but our observations of it have been perturbed by independent zero-mean Gaussian noise. We aim to recover this latent function using a Gaussian process.

We begin by defining a zero-mean Gaussian process prior with a radial basis function kernel and assume the likelihood to be Gaussian.

prior = gpx.Prior(kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=x.shape[0])
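To make the kernel choice concrete, the following is a minimal sketch of a radial basis function Gram matrix written in plain JAX. It does not use GPJax: the function name `rbf_gram` and its `lengthscale` and `variance` arguments are assumptions chosen for illustration, and GPJax's own kernel may be parameterised differently.

```python
import jax.numpy as jnp


def rbf_gram(x1, x2, lengthscale=1.0, variance=1.0):
    """Return k(x1, x2) = variance * exp(-0.5 * ||x1 - x2||^2 / lengthscale^2)."""
    # Pairwise squared Euclidean distances, shape (n, m).
    sq_dists = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dists / lengthscale**2)


# Hypothetical inputs: five evenly spaced points in one dimension.
x = jnp.linspace(-3.0, 3.0, 5).reshape(-1, 1)
K = rbf_gram(x, x)
```

The resulting matrix is symmetric with the kernel variance on its diagonal, and its entries decay as the distance between inputs grows.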

The posterior is then constructed as the product of our prior and our likelihood.

posterior = prior * likelihood

Equipped with the posterior, we proceed to train the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood.

We begin by defining a set of initial parameter values through the initialise callable.

params, _, constrainer, unconstrainer = gpx.initialise(posterior)
params = gpx.transform(params, unconstrainer)
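The reason for moving to an unconstrained space is that parameters such as a lengthscale or variance must stay positive, whereas a gradient step can land anywhere on the real line. A minimal sketch of one such pair of maps is given below, assuming a softplus bijection. The names `constrain` and `unconstrain` are hypothetical, and the transformations GPJax actually applies may differ.

```python
import jax.numpy as jnp


def constrain(raw):
    """Map an unconstrained real value to a positive value via softplus."""
    return jnp.logaddexp(raw, 0.0)  # log(1 + exp(raw)), computed stably


def unconstrain(positive):
    """Inverse softplus: map a positive value back to the real line."""
    return positive + jnp.log(-jnp.expm1(-positive))


lengthscale = jnp.array(1.0)
raw = unconstrain(lengthscale)   # value the optimiser would work with
recovered = constrain(raw)       # maps back to the original lengthscale
```

Optimising `raw` and applying `constrain` afterwards guarantees a positive parameter regardless of the size of the gradient steps.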

Next, we define the marginal log-likelihood and wrap it in JAX's just-in-time (JIT) compilation to accelerate training. Passing negative=True returns the negative marginal log-likelihood, so that minimising it maximises the marginal log-likelihood. Notice that this is the first point at which data enters our model. Model building works this way in principle too: we first define our prior model, then observe some data and use this data to build a posterior.

mll = jit(posterior.marginal_log_likelihood(training, constrainer, negative=True))
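For a Gaussian likelihood the marginal log-likelihood has a closed form, log p(y) = -0.5 yᵀ(K + σ²I)⁻¹y - 0.5 log|K + σ²I| - (n/2) log 2π. The sketch below evaluates its negative in plain JAX using a Cholesky factorisation. It is an illustration under stated assumptions rather than GPJax's implementation: the function `negative_mll` and the parameter names `lengthscale`, `variance` and `obs_noise` are hypothetical, and the inputs are assumed to be one-dimensional.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def negative_mll(params, x, y):
    """Negative log marginal likelihood of a zero-mean GP with an RBF kernel."""
    n = x.shape[0]
    sq_dists = (x - x.T) ** 2  # valid for one-dimensional inputs of shape (n, 1)
    K = params["variance"] * jnp.exp(-0.5 * sq_dists / params["lengthscale"] ** 2)
    L = jnp.linalg.cholesky(K + params["obs_noise"] * jnp.eye(n))
    alpha = cho_solve((L, True), y)  # (K + obs_noise * I)^{-1} y
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    mll = -0.5 * jnp.sum(y * alpha) - 0.5 * log_det - 0.5 * n * jnp.log(2.0 * jnp.pi)
    return -mll


# Hypothetical noise-free data and parameter values, for illustration only.
x = jnp.linspace(-3.0, 3.0, 20).reshape(-1, 1)
y = jnp.sin(x)
params = {"lengthscale": 1.0, "variance": 1.0, "obs_noise": 0.1}
value, grads = jax.jit(jax.value_and_grad(negative_mll))(params, x, y)
```

Here `grads` is a dictionary with the same keys as `params`, which is the structure a gradient-based optimiser needs in order to update each hyperparameter.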

Finally, we use the Adam optimiser from JAX's example libraries and run an optimisation loop.

opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)
opt_state = opt_init(params)

def step(i, opt_state):
    params = get_params(opt_state)
    gradients = jax.grad(mll)(params)
    return opt_update(i, gradients, opt_state)

for i in range(100):
    opt_state = step(i, opt_state)

Now that our parameters are optimised, we transform these back to their original constrained space. Using their learned values, we can obtain the posterior distribution of the latent function at novel test points.

final_params = gpx.transform(get_params(opt_state), constrainer)

xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)

latent_distribution = posterior(training, final_params)(xtest)
# Use the optimised, constrained parameters for the likelihood as well.
predictive_distribution = likelihood(latent_distribution, final_params)

predictive_mean = predictive_distribution.mean()
predictive_stddev = predictive_distribution.stddev()
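For reference, the quantities computed above correspond to the standard GP regression equations: the predictive mean is K*ᵀ(K + σ²I)⁻¹y, and the predictive variance is the latent variance k** - K*ᵀ(K + σ²I)⁻¹K* plus the observation noise σ². The sketch below evaluates these in plain JAX under stated assumptions: the helpers `rbf` and `predict` and the fixed hyperparameter values are hypothetical, the inputs are one-dimensional, and this is not GPJax's implementation.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def rbf(a, b, lengthscale=1.0, variance=1.0):
    """RBF kernel matrix between one-dimensional inputs of shape (n, 1) and (m, 1)."""
    return variance * jnp.exp(-0.5 * (a - b.T) ** 2 / lengthscale**2)


def predict(x, y, xtest, obs_noise=0.01):
    """Predictive mean and standard deviation of a zero-mean GP at xtest."""
    Kxx = rbf(x, x) + obs_noise * jnp.eye(x.shape[0])
    Kxt = rbf(x, xtest)
    L = jnp.linalg.cholesky(Kxx)
    mean = Kxt.T @ cho_solve((L, True), y)
    # Marginal variance of the latent function at each test point.
    reduction = jnp.sum(Kxt * cho_solve((L, True), Kxt), axis=0)
    latent_var = jnp.diag(rbf(xtest, xtest)) - reduction
    # Adding the observation noise gives the predictive variance.
    stddev = jnp.sqrt(jnp.maximum(latent_var, 0.0) + obs_noise)
    return mean.squeeze(), stddev


# Hypothetical noise-free observations of a sine function.
x = jnp.linspace(-3.0, 3.0, 30).reshape(-1, 1)
y = jnp.sin(x)
xtest = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)
predictive_mean, predictive_stddev = predict(x, y, xtest)
```

Only the diagonal of the predictive covariance is computed here, which is sufficient for plotting a mean with pointwise uncertainty bands.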

Installation

Stable version

To install the latest stable version of GPJax, run:

pip install gpjax

Development version

To install the latest, possibly unstable, version, follow the steps below. It is by no means compulsory, but we do advise that you do all of the below inside a virtual environment.

git clone https://github.com/thomaspinder/GPJax.git
cd GPJax
python setup.py develop

We then recommend you check your installation using the supplied unit tests.

python -m pytest tests/

Citing GPJax

If you use GPJax in your research, please cite our JOSS paper. A sample BibTeX entry is:

@article{Pinder2022,
  doi = {10.21105/joss.04455},
  url = {https://doi.org/10.21105/joss.04455},
  year = {2022},
  publisher = {The Open Journal},
  volume = {7},
  number = {75},
  pages = {4455},
  author = {Thomas Pinder and Daniel Dodd},
  title = {GPJax: A Gaussian Process Framework in JAX},
  journal = {Journal of Open Source Software}
}
