Hippox: High-order Polynomial Projection Operators for JAX

What is Hippox?

Hippox provides a simple dataclass for initializing High-order Polynomial Projection Operators (HiPPOs) as parameters in JAX neural network libraries such as Flax and Haiku.
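To make the idea concrete, here is a minimal sketch of one well-known HiPPO measure, LegS (scaled Legendre), built from the closed-form expression in the HiPPO paper. The function name `make_legs_hippo` is hypothetical and is not part of hippox. It only illustrates the kind of (A, B) pair a HiPPO initializer produces for a given state size.

```python
import jax.numpy as jnp


def make_legs_hippo(state_size: int):
    """Hypothetical helper: HiPPO-LegS (A, B) from the HiPPO paper's closed form.

    A[n, k] = -sqrt(2n+1) * sqrt(2k+1)   if n > k
              -(n + 1)                   if n == k
              0                          if n < k
    B[n]    =  sqrt(2n+1)
    """
    p = jnp.sqrt(1.0 + 2.0 * jnp.arange(state_size))
    # Outer product sqrt(2n+1) * sqrt(2k+1), keeping only the lower triangle.
    a = jnp.tril(p[:, None] * p[None, :])
    # The diagonal of the outer product is 2n+1; subtracting n leaves n+1.
    a = a - jnp.diag(jnp.arange(state_size))
    # Negate so the state matrix is stable; B is returned as a column [N, 1].
    return -a, p[:, None]


A, B = make_legs_hippo(4)
print(A.shape, B.shape)  # (4, 4) (4, 1)
```

The column shape `[N, 1]` for B mirrors the `shape=[state_size, 1]` used for the input matrix in the examples below.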

Example

Here is an example of initializing HiPPO parameters inside a Haiku module:

import haiku as hk
from hippox.main import Hippo

class MyHippoModule(hk.Module):
    def __init__(self, state_size, measure, name=None):
        super().__init__(name=name)
        # Build the HiPPO matrices for the chosen measure.
        _hippo = Hippo(state_size=state_size, measure=measure)
        _hippo()

        # The diagonal state matrix is stored as separate real and imaginary parts.
        self._lambda_real = hk.get_parameter(
            'lambda_real',
            shape=[state_size,],
            init=_hippo.lambda_initializer('real')
        )
        self._lambda_imag = hk.get_parameter(
            'lambda_imaginary',
            shape=[state_size,],
            init=_hippo.lambda_initializer('imaginary')
        )
        self._state_matrix = self._lambda_real + 1j * self._lambda_imag

        self._input_matrix = hk.get_parameter(
            'input_matrix',
            shape=[state_size, 1],
            init=_hippo.b_initializer()
        )

    def __call__(self, x, prev_state):
        # The state matrix is a vector of eigenvalues (a diagonal matrix),
        # so it acts on the state elementwise; x has shape [1].
        new_state = self._state_matrix * prev_state + self._input_matrix @ x
        return new_state
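In the Haiku module the state matrix is diagonal, stored as a vector of complex eigenvalues, so one update is an elementwise product. HiPPO matrices are continuous-time, and diagonal models such as S4D usually discretize them before running the recurrence. The following is a minimal sketch, assuming zero-order-hold (ZOH) discretization with a hypothetical step size `step` and a single-input sequence, of running the diagonal recurrence over a whole sequence with `jax.lax.scan`. The helper `run_diagonal_ssm` is hypothetical, and the eigenvalues are placeholder S4D-Lin style values (-1/2 + iπn), not output from hippox.

```python
import jax
import jax.numpy as jnp


def run_diagonal_ssm(lambda_real, lambda_imag, b, inputs, step=0.01):
    """Hypothetical helper: x_k = Abar * x_{k-1} + Bbar * u_k for diagonal A.

    Zero-order hold discretization (as in the S4D paper):
      Abar = exp(step * Lambda),  Bbar = (Abar - 1) / Lambda * B
    """
    lam = lambda_real + 1j * lambda_imag      # diagonal of A, shape [N]
    a_bar = jnp.exp(step * lam)               # shape [N]
    b_bar = (a_bar - 1.0) / lam * b[:, 0]     # shape [N]; b has shape [N, 1]

    def step_fn(state, u):
        # Elementwise product: a diagonal matrix applied to a vector.
        new_state = a_bar * state + b_bar * u
        return new_state, new_state

    init = jnp.zeros_like(lam)                # complex zero state
    _, states = jax.lax.scan(step_fn, init, inputs)
    return states                             # shape [L, N]


# Placeholder eigenvalues (S4D-Lin style), not produced by hippox.
n = 4
states = run_diagonal_ssm(
    lambda_real=-0.5 * jnp.ones(n),
    lambda_imag=jnp.pi * jnp.arange(n),
    b=jnp.ones((n, 1)),
    inputs=jnp.ones(10),
)
print(states.shape)  # (10, 4)
```

Keeping the real part of every eigenvalue negative keeps |Abar| below 1, so the discretized recurrence stays stable.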

If you are using a library such as Equinox, which takes JAX arrays directly as parameters rather than initializer functions, you can read the HiPPO matrices straight from the object returned when the Hippo instance is called:

import equinox as eqx
import jax.numpy as jnp
from hippox.main import Hippo

class MyHippoModule(eqx.Module):
    A: jnp.ndarray
    B: jnp.ndarray

    def __init__(self, state_size, measure):
        # Calling the Hippo instance returns an object holding the matrices.
        _hippo = Hippo(state_size=state_size, measure=measure)
        _hippo_params = _hippo()

        self.A = _hippo_params.state_matrix
        self.B = _hippo_params.input_matrix

    def __call__(self, x, prev_state):
        # One linear state update: A is [N, N], B is [N, 1], x has shape [1].
        new_state = self.A @ prev_state + self.B @ x
        return new_state
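The Equinox module applies the full state matrix once per call. If the matrices are continuous-time, as HiPPO matrices are in the papers listed under References, S4-style models discretize them first. This is a sketch under stated assumptions, not hippox's API: it rebuilds the LegS (A, B) pair from the HiPPO paper's closed form, applies the bilinear transform used in the S4 paper with a hypothetical step size of 1/100, and runs the recurrence over a constant input with `jax.lax.scan`. All helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def make_legs_hippo(state_size):
    # Hypothetical helper: HiPPO-LegS (A, B) from the HiPPO paper's closed form.
    p = jnp.sqrt(1.0 + 2.0 * jnp.arange(state_size))
    a = jnp.tril(p[:, None] * p[None, :]) - jnp.diag(jnp.arange(state_size))
    return -a, p[:, None]


def discretize_bilinear(a, b, step):
    """Bilinear (Tustin) discretization, as used in the S4 paper:
    Abar = (I - step/2 A)^-1 (I + step/2 A),  Bbar = (I - step/2 A)^-1 step B
    """
    eye = jnp.eye(a.shape[0])
    inv = jnp.linalg.inv(eye - (step / 2.0) * a)
    return inv @ (eye + (step / 2.0) * a), (inv * step) @ b


def run_ssm(a_bar, b_bar, inputs):
    # inputs has shape [L, 1]; each step computes x_k = Abar x_{k-1} + Bbar u_k.
    def step_fn(state, u):
        new_state = a_bar @ state + b_bar @ u
        return new_state, new_state

    init = jnp.zeros(a_bar.shape[0])
    _, states = jax.lax.scan(step_fn, init, inputs)
    return states  # shape [L, N]


A, B = make_legs_hippo(8)
A_bar, B_bar = discretize_bilinear(A, B, step=1.0 / 100)
states = run_ssm(A_bar, B_bar, jnp.ones((100, 1)))
print(states.shape)  # (100, 8)
```

Because A is lower triangular with A[0, 0] = -1 and B[0] = 1, the first state coordinate follows dx/dt = -x + u on its own. With u = 1 it approaches 1 - e^(-t), so after 100 steps of size 0.01 it should be close to 0.632.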

Installation

hippox can be installed from PyPI:

pip install hippox

References

Repositories

  1. https://github.com/HazyResearch/state-spaces - Original paper implementations in PyTorch

  2. https://github.com/srush/annotated-s4 - JAX implementation of S4 models (S4, S4D, DSS)

Papers

  1. HiPPO: Recurrent Memory with Optimal Polynomial Projections: https://arxiv.org/abs/2008.07669 - Original paper which introduced HiPPOs

  2. Efficiently Modeling Long Sequences with Structured State Spaces: https://arxiv.org/abs/2111.00396 - S4 paper, introduces the normal/diagonal plus low-rank decomposition

  3. How to Train Your HiPPO: State Space Models with Generalized Orthogonal Basis Projections: https://arxiv.org/abs/2206.12037 - Generalizes and explains the core principles behind HiPPO

  4. On the Parameterization and Initialization of Diagonal State Space Models: https://arxiv.org/abs/2206.11893 - S4D paper, details and explains the diagonal only parameterization
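The S4 paper's normal plus low-rank (NPLR) form and the diagonal parameterization from the S4D paper can be checked numerically. This minimal sketch assumes the LegS measure and the low-rank factor P_n = sqrt(n + 1/2) used in the S4 paper and the Annotated S4. It splits A = S - P Pᵀ, confirms that S is -1/2 I plus a skew-symmetric matrix, and diagonalizes S. Eigenvalues obtained this way are one plausible source for the real and imaginary parts that `lambda_initializer('real')` and `lambda_initializer('imaginary')` supply, but whether hippox computes them like this is an assumption. The helper `legs_nplr` is hypothetical.

```python
import jax.numpy as jnp


def legs_nplr(state_size):
    """Hypothetical helper: split HiPPO-LegS as A = S - P P^T with S normal."""
    p_b = jnp.sqrt(1.0 + 2.0 * jnp.arange(state_size))
    a = -(jnp.tril(p_b[:, None] * p_b[None, :]) - jnp.diag(jnp.arange(state_size)))
    # Low-rank factor P_n = sqrt(n + 1/2).
    p = jnp.sqrt(0.5 + jnp.arange(state_size))
    # Normal part: -1/2 on the diagonal plus a skew-symmetric matrix.
    s = a + p[:, None] * p[None, :]
    return a, s, p


A, S, P = legs_nplr(8)
skew = S + 0.5 * jnp.eye(8)
print(bool(jnp.allclose(skew, -skew.T, atol=1e-4)))  # True

# Since S = -1/2 I + skew-symmetric, every eigenvalue of S has real part -1/2;
# the imaginary parts come from the skew-symmetric component.
lam = jnp.linalg.eigvals(S)
print(bool(jnp.allclose(lam.real, -0.5, atol=1e-3)))  # True
```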

Contributors

jpgoodale