
laplax's Introduction

Laplax Logo

Laplax

What is laplax?

The laplax package aims to provide a performant, minimal, and practical implementation of Laplace approximation techniques in jax. This package is designed to support a wide range of scientific libraries, initially focusing on compatibility with popular neural network libraries such as equinox, flax.linen, and flax.nnx. Our goal is to create a flexible tool for both practical applications and research, enabling rapid iteration and comparison of new approaches.

Design Philosophy

The development of laplax is guided by the following principles:

  • Minimal Dependencies: The package only depends on jax, ensuring compatibility and ease of integration.

  • Matrix-Vector Product Focus: The core of our implementation revolves around efficient matrix-vector products. By passing around callables, we maintain loose coupling between components, allowing easy interaction with various other packages, including linear operator libraries in jax (a minimal sketch of this pattern follows this list).

  • Performance and Practicality: We prioritize a performant and minimal implementation that serves practical needs. The package offers a simple API for basic use cases while primarily serving as a reference implementation for researchers to compare new methods or iterate quickly over experiments.

  • PyTree-Centric Structure: Internally, the package is structured around PyTrees. This design choice allows us to defer materialization until necessary, optimizing performance and memory usage.
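
The two points above boil down to a simple pattern: curvature, prior, and covariance operators are represented as closures that map a parameter PyTree to a parameter PyTree. The sketch below is illustrative only (the function and variable names are not part of laplax); it shows a diagonal prior precision exposed as such a matrix-vector callable, with the matrix never materialized.

```python
# Minimal sketch of the callable-over-PyTrees pattern (names are illustrative,
# not the laplax API): a diagonal prior precision exposed as a matrix-vector
# product, without ever materializing the matrix.
import jax
import jax.numpy as jnp


def diagonal_precision_mv(precision_tree):
    """Return a callable computing (diagonal precision) @ v for a PyTree v."""

    def mv(vector_tree):
        # Element-wise product, leaf by leaf.
        return jax.tree_util.tree_map(jnp.multiply, precision_tree, vector_tree)

    return mv


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
prior_prec = jax.tree_util.tree_map(lambda p: jnp.full_like(p, 2.0), params)

mv = diagonal_precision_mv(prior_prec)
result = mv(params)  # same PyTree structure as params
```

Because every operator shares this signature, callables can be composed or swapped (for example, exchanged with a linear operator library) without touching the rest of the pipeline.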

Roadmap and Contributions

We're developing this package in public; discussions about the roadmap and feature priorities take place in the Issues section. If you're interested in contributing or want to see what's planned, please check them out.


laplax's Issues

Roadmap for Version 0.1 - Basic Laplace Approximation Functionality

Possible Roadmap

This issue outlines possible key features and components planned for the first version of the laplax package. The focus is on providing only the foundational elements necessary for Laplace approximation in JAX.

1. Basic Laplace Function

  • Function Signature: laplace(model_fn, params, loss_fn, data) -> (model_fn, params, get_cov_scale)
  • Description: This function will be the core entry point for performing the Laplace approximation. It will take a model function (model_fn), parameters (params), a loss function (loss_fn), and a dataset (data). It returns the modified model function, updated parameters, and a function to compute covariance scaling.
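
Purely to make the proposed signature concrete, here is a toy sketch of how such an entry point could be wired together. The diagonal empirical-Fisher curvature, the (inputs, targets) data layout, and the per-example loss signature are assumptions for illustration, not the planned implementation.

```python
# Hypothetical sketch of the proposed entry point. A diagonal empirical-Fisher
# curvature stands in for the GGN approximations discussed below; this is not
# the planned laplax implementation.
import jax
import jax.numpy as jnp


def laplace(model_fn, params, loss_fn, data):
    inputs, targets = data

    def pointwise_loss(p, x, y):
        return loss_fn(model_fn(p, x), y)

    # Per-example gradients; their squared sum serves as a diagonal curvature proxy.
    grads = jax.vmap(jax.grad(pointwise_loss), in_axes=(None, 0, 0))(
        params, inputs, targets
    )
    curv_diag = jax.tree_util.tree_map(lambda g: jnp.sum(g**2, axis=0), grads)

    def get_cov_scale(prior_prec):
        def cov_scale_mv(vector_tree):
            # Diagonal case: (curvature + prior)^(-1/2) applied element-wise.
            return jax.tree_util.tree_map(
                lambda c, v: v / jnp.sqrt(c + prior_prec), curv_diag, vector_tree
            )

        return cov_scale_mv

    return model_fn, params, get_cov_scale
```

A caller would then obtain the pieces via model_fn, params, get_cov_scale = laplace(model_fn, params, loss_fn, (inputs, targets)) and build cov_scale_mv = get_cov_scale(prior_prec=1.0).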

2. Covariance Scaling Function (get_cov_scale)

  • Function Signature: get_cov_scale(prior, ...) -> cov_scale_mv
  • Description: This function will accept parameters such as a prior and return a callable for the covariance-square-root-matrix-vector product (cov_scale_mv). This is essential for calibration and evaluation of the model's uncertainty.
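
To hint at why this square-root callable matters for calibration and evaluation, the following sketch draws weight-space posterior samples from it; cov_scale_mv is assumed to behave like the diagonal example above, and all names are hypothetical.

```python
# Illustrative use of a covariance square-root callable: weight-space posterior
# samples of the form params + cov_scale_mv(standard normal noise).
import jax
import jax.numpy as jnp


def sample_weights(key, params, cov_scale_mv, num_samples=10):
    leaves, treedef = jax.tree_util.tree_flatten(params)

    def single_sample(k):
        keys = jax.random.split(k, len(leaves))
        noise = jax.tree_util.tree_unflatten(
            treedef,
            [jax.random.normal(kk, leaf.shape) for kk, leaf in zip(keys, leaves)],
        )
        scaled = cov_scale_mv(noise)  # multiply noise by the covariance square root
        return jax.tree_util.tree_map(jnp.add, params, scaled)

    return [single_sample(k) for k in jax.random.split(key, num_samples)]
```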

3. GGN Approximations

We aim to support the following Generalized Gauss-Newton (GGN) approximations:

  • Full GGN: No approximation (its exact matrix-vector product is sketched below).
  • Diagonal GGN: Provides a diagonal approximation.
  • KFAC (Kronecker-Factored Approximate Curvature): A block-diagonal approximation, where the blocks are given by two Kronecker products.
  • Low-Rank: A low-rank approximation method.

Each method will include:

  • A specific matrix-vector callable.
  • An inverse operation that returns a callable for products with the inverse of the curvature-plus-prior matrix.
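
For the full-GGN case, the matrix-vector callable can be built entirely from jvp/vjp, so the GGN matrix is never materialized. The sketch below is illustrative (names and the data layout are assumptions, not the laplax interface); the diagonal, KFAC, and low-rank variants would swap in different bodies behind the same callable signature.

```python
# Sketch of a full (unapproximated) GGN-vector product, v -> J^T H J v, where J
# is the Jacobian of the model outputs w.r.t. the parameters and H is the
# Hessian of the loss w.r.t. the model outputs. Names are illustrative.
import jax


def ggn_mv(model_fn, loss_fn, params, inputs, targets):
    def apply(p):
        return model_fn(p, inputs)

    outputs = apply(params)

    def mv(vector_tree):
        # J v: push the parameter-space vector through the model's Jacobian.
        _, jv = jax.jvp(apply, (params,), (vector_tree,))
        # H (J v): Hessian of the loss w.r.t. the outputs, applied to J v.
        _, hjv = jax.jvp(
            jax.grad(lambda o: loss_fn(o, targets)), (outputs,), (jv,)
        )
        # J^T (H J v): pull back to parameter space.
        (jthjv,) = jax.vjp(apply, params)[1](hjv)
        return jthjv

    return mv
```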

4. Pushforward from weight-space uncertainty (cov_scale_mv) to output-space uncertainty (out_cov_scale_mv)

Initial support will include:

  • Linearization (a sketch follows below)
  • Monte Carlo Sampling

Special consideration is needed for classification tasks, particularly for deriving uncertainty in probit space.
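
For the linearization route, a hedged sketch: if cov_scale_mv represents a weight-space square root S with covariance S S^T, then J S is an output-space square root, and its matrix-vector product costs one jvp per call. The names below are hypothetical.

```python
# Linearized pushforward sketch: out_cov_scale_mv(v) = J @ (S v), where S is the
# weight-space covariance square root behind cov_scale_mv and J is the Jacobian
# of the model at input x. Names are illustrative, not the laplax interface.
import jax


def linearized_out_cov_scale_mv(model_fn, params, x, cov_scale_mv):
    def mv(vector_tree):
        scaled = cov_scale_mv(vector_tree)  # S v, a parameter PyTree
        _, out_tangent = jax.jvp(lambda p: model_fn(p, x), (params,), (scaled,))
        return out_tangent  # J S v, living in output space

    return mv
```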

5. Calibration and Evaluation

  • Metrics Support: We will support a small set of predefined metric functions.
  • Custom Metric Functions: Users can provide custom metric functions that take inputs such as map_prediction, mean_prediction, cov_prediction, and true_target.
  • Evaluation Process: These functions will be applied over a specified set of validation or test data points.
  • Calibration: Initially, we will support grid search as a method for calibration, requiring the metric to be a single, minimizable function.
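
To make the metric and calibration items concrete, here is a hedged sketch of a user-supplied metric with the suggested arguments and a minimal grid search over the prior precision; nothing here is the planned laplax interface.

```python
# Illustrative custom metric and grid-search calibration. The metric signature
# follows the argument names suggested above; the Gaussian NLL with a diagonal
# predictive covariance and the scalar prior-precision grid are assumptions.
import jax.numpy as jnp


def nll_metric(map_prediction, mean_prediction, cov_prediction, true_target):
    # Gaussian negative log-likelihood with a diagonal predictive covariance
    # (map_prediction is unused by this particular metric).
    var = jnp.maximum(cov_prediction, 1e-8)
    return jnp.mean(
        0.5 * jnp.log(2 * jnp.pi * var)
        + 0.5 * (true_target - mean_prediction) ** 2 / var
    )


def grid_search(evaluate, prior_precisions):
    """evaluate: prior precision -> scalar metric averaged over validation data."""
    scores = [evaluate(p) for p in prior_precisions]
    best = int(jnp.argmin(jnp.asarray(scores)))
    return prior_precisions[best], scores[best]
```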

Happy to discuss these thoughts below.
