kfac-jax's People

Contributors

botev, chsigg, fabianp, hawkinsp, hbq1, james-martens, joeljennings, rchen152, sauravmaheshkar, sharadmv, superbobry

kfac-jax's Issues

TypeError: 'ShapedArray' object is not iterable

Hi,

I tried to run the example code, but it stops at primal_output = self.bind(*arg_values, **kwargs) and raises "TypeError: 'ShapedArray' object is not iterable". Could you please help me solve this problem? Thanks.

Quickstart example with different NN libraries does not tag correctly

Hey,

I am adapting the quickstart example to equinox for use in my project. However, only the bias seems to be tagged correctly (with Auto[scale_and_shift_tag_0]); the weight matrix is tagged as 'Orphan'. I also tried flax, where only the weight matrix is tagged correctly (with Auto[dense_tag_0]) and the bias is tagged as 'Orphan'. Finally, with pure jax, again only the bias seems to be tagged correctly.

I am not sure what I am doing wrong here; I made only minimal changes. In the test script below, you can switch between the four libraries with the lib_type variable. I would appreciate any help.

import haiku as hk
import jax
import jax.numpy as jnp
import kfac_jax
from absl import logging
import sys
import equinox as eqx
import flax
import jax.random as random

logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Hyper parameters
NUM_CLASSES = 10
L2_REG = 1e-3
NUM_BATCHES = 5
NUM_FEATURES = 20
rng = jax.random.PRNGKey(42)


def make_dataset_iterator(batch_size):
  # Dummy dataset, in practice this should be your dataset pipeline
  for _ in range(NUM_BATCHES):
    yield jnp.zeros([batch_size, NUM_FEATURES]), jnp.ones([batch_size], dtype="int32")


def softmax_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray):
  """Softmax cross entropy loss."""
  # We assume integer labels
  assert logits.ndim == targets.ndim + 1

  # Tell KFAC-JAX this model represents a classifier
  # See https://kfac-jax.readthedocs.io/en/latest/overview.html#supported-losses
  kfac_jax.register_softmax_cross_entropy_loss(logits, targets)
  log_p = jax.nn.log_softmax(logits, axis=-1)
  return - jax.vmap(lambda x, y: x[y])(log_p, targets)


lib_type = 'hk'  # one of 'hk', 'eqx', 'flax', 'jax'
#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
if lib_type == 'hk':
  def model_fn(x):
    return hk.nets.MLP(
      output_sizes=(50, 50, NUM_CLASSES),
      with_bias=True,
      activation=jax.nn.tanh,
    )(x)

  hk_model = hk.without_apply_rng(hk.transform(model_fn))

  def loss_fn(model_params, model_batch):
    """The loss function to optimize."""
    x, y = model_batch
    logits = hk_model.apply(model_params, x)
    loss = jnp.mean(softmax_cross_entropy(logits, y))

    return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0

  optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=L2_REG,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
  )

  input_dataset = make_dataset_iterator(128)
  dummy_images, dummy_labels = next(input_dataset)
  rng, key = jax.random.split(rng)
  params = hk_model.init(key, dummy_images)
  rng, key = jax.random.split(rng)
  opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
elif lib_type == 'eqx':
  class simple_net(eqx.Module):
      net: callable

      def __init__(self, key=None):
        keys = jax.random.split(key,  2)
        self.net = eqx.nn.MLP(NUM_FEATURES, NUM_CLASSES, 100, 0, activation=jax.nn.tanh, key=keys[1])

      def __call__(self, x):
        return self.net(x)

  eqx_model = simple_net(rng)
  params, static = eqx.partition(eqx_model, eqx.is_inexact_array)


  def loss_fn(model_params, static, model_batch):
    """The loss function to optimize."""
    x, y = model_batch
    eqx_model = eqx.combine(model_params, static)
    logits = jax.vmap(eqx_model)(x)
    loss = jnp.mean(softmax_cross_entropy(logits, y))

    return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0
  kfac_loss_fn = lambda params, batch: loss_fn(params, static, batch)

  optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(kfac_loss_fn),
    l2_reg=L2_REG,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
  )

  input_dataset = make_dataset_iterator(128)
  dummy_images, dummy_labels = next(input_dataset)
  rng, key = jax.random.split(rng)
  opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
elif lib_type == 'flax':
  class MLP(flax.linen.Module):

    def setup(self):
      self.dense1 = flax.linen.Dense(32)
      self.dense2 = flax.linen.Dense(NUM_CLASSES)

    def __call__(self, x):
      x = self.dense1(x)
      x = flax.linen.relu(x)
      x = self.dense2(x)
      return x


  flax_model = MLP()
  params = flax_model.init(rng, jnp.zeros([128, NUM_FEATURES]))

  def loss_fn(model_params, model_batch):
    """The loss function to optimize."""
    x, y = model_batch
    logits = flax_model.apply(model_params, x)
    loss = jnp.mean(softmax_cross_entropy(logits, y))

    return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0

  optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=L2_REG,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
  )

  input_dataset = make_dataset_iterator(128)
  dummy_images, dummy_labels = next(input_dataset)
  rng, key = jax.random.split(rng)
  opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
elif lib_type == 'jax':
  def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))


  # Initialize all layers for a fully-connected neural network with sizes "sizes"
  def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

  def relu(x):
    return jnp.maximum(0, x)

  def predict(params, x):
    activations = x
    for w, b in params[:-1]:
      outputs = jnp.dot(w, activations) + b
      activations = relu(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits


  params = init_network_params([NUM_FEATURES, 100, NUM_CLASSES], random.PRNGKey(0))


  def loss_fn(model_params, model_batch):
    """The loss function to optimize."""
    x, y = model_batch
    logits = jax.vmap(predict, in_axes=(None, 0))(model_params, x)
    loss = jnp.mean(softmax_cross_entropy(logits, y))

    return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0

  optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=L2_REG,
    value_func_has_aux=False,
    value_func_has_state=False,
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=1.0,
    multi_device=False,
  )

  input_dataset = make_dataset_iterator(128)
  dummy_images, dummy_labels = next(input_dataset)
  rng, key = jax.random.split(rng)
  opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))


# Training loop
for i, batch in enumerate(input_dataset):
  rng, key = jax.random.split(rng)
  params, opt_state, stats = optimizer.step(
      params, opt_state, key, batch=batch, global_step_int=i)
  print(i, stats)
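The issue does not establish why the weights end up as 'Orphan'. kfac-jax detects layers by matching patterns in the traced computation, so one hypothesis worth testing (not confirmed in this thread) is that the matcher expects the batch-leading x @ w + b form a Haiku linear layer produces. The pure-JAX predict above instead applies jnp.dot(w, x) per example under jax.vmap, and that traces to a differently arranged contraction. The minimal sketch below assumes this hypothesis. It rewrites predict in batch-leading form with (in, out) weights and checks that it produces the same logits. The function names are illustrative and not part of kfac-jax. Whether this change fixes the tagging would still have to be checked against kfac-jax itself.

```python
import jax
import jax.numpy as jnp


def predict_vmapped(params, x):
    """Layout from the issue: per-example jnp.dot(w, x), w of shape (out, in)."""
    def single(xi):
        act = xi
        for w, b in params[:-1]:
            act = jax.nn.relu(jnp.dot(w, act) + b)
        final_w, final_b = params[-1]
        return jnp.dot(final_w, act) + final_b
    return jax.vmap(single)(x)


def predict_batched(params, x):
    """Batch-leading layout: act @ w + b, w of shape (in, out)."""
    act = x
    for w, b in params[:-1]:
        act = jax.nn.relu(act @ w + b)
    final_w, final_b = params[-1]
    return act @ final_w + final_b


k1, k2, k3, k4, kx = jax.random.split(jax.random.PRNGKey(0), 5)
# (out, in) weights, matching the issue's init_network_params layout.
params_out_in = [
    (0.1 * jax.random.normal(k1, (100, 20)), 0.1 * jax.random.normal(k2, (100,))),
    (0.1 * jax.random.normal(k3, (10, 100)), 0.1 * jax.random.normal(k4, (10,))),
]
# Same parameters with weights transposed to (in, out).
params_in_out = [(w.T, b) for w, b in params_out_in]
x = jax.random.normal(kx, (8, 20))

logits_a = predict_vmapped(params_out_in, x)
logits_b = predict_batched(params_in_out, x)
```

The two versions are numerically equivalent. Only the structure of the traced graph differs, and that structure is what a pattern-based tagger inspects.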

Using K-FAC with physics-based losses

Hey,

Thank you for the implementation.

From the guide, I saw that I have to register loss functions to be able to use K-FAC.
In my case, the loss function is a FEM simulation on the outputs of the network, combined with some other functions (postprocessing, filtering, etc.).

Will it be possible to use K-FAC?
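K-FAC approximates a Gauss-Newton/Fisher-type curvature that is defined through a loss registered on some output, such as the kfac_jax.register_softmax_cross_entropy_loss call in the quickstart script. Suppose the physics loss can be written as a sum of squared residuals r(theta) = fem(network(theta)). Its generalized Gauss-Newton matrix is then J_r^T J_r, where J_r is the Jacobian of the residuals. The sketch below computes that matrix-vector product in plain JAX with jax.jvp and jax.vjp. network and fem_residual are hypothetical stand-ins (a one-layer net and a 1D finite-difference operator). The sketch does not show whether kfac-jax can trace and tag layers through a particular FEM solver.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def network(params, x):
    # Hypothetical one-layer network producing a nodal field u.
    return jnp.tanh(x @ params["w"] + params["b"])


def fem_residual(u):
    # Hypothetical stand-in for the FEM solve + post-processing:
    # a 1D second-difference operator applied to the nodal values.
    return u[:-2] - 2.0 * u[1:-1] + u[2:]


def make_residual_fn(unravel, x):
    # r(theta) = fem_residual(network(theta, x)), on flat parameters.
    return lambda flat: fem_residual(network(unravel(flat), x).ravel())


def ggn_vector_product(residual_fn, flat_params, v):
    # For loss = 0.5 * ||r||^2 the Hessian of the loss w.r.t. r is the
    # identity, so the generalized Gauss-Newton matrix is J^T J.
    _, jv = jax.jvp(residual_fn, (flat_params,), (v,))  # J v
    _, vjp_fn = jax.vjp(residual_fn, flat_params)
    (jtjv,) = vjp_fn(jv)                                  # J^T (J v)
    return jtjv


key_w, key_x, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"w": jax.random.normal(key_w, (3, 1)), "b": jnp.zeros((1,))}
flat_params, unravel = ravel_pytree(params)
x = jax.random.normal(key_x, (8, 3))
residual_fn = make_residual_fn(unravel, x)
v = jax.random.normal(key_v, flat_params.shape)
gv = ggn_vector_product(residual_fn, flat_params, v)
```

The product never forms J explicitly, so the cost is one forward-mode and one reverse-mode pass through the network and the (differentiable) simulation.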

Value functions returning state objects not supported

Hi - thank you for the great project!

I'm working with models which carry an internal state, so their forward pass functions return an extra state object. This causes a crash from KFAC-JAX, because that state object is only accounted for at the input, rather than the output.

In optimizer.convert_value_and_grad_to_value_func, the flag has_aux decides whether to return the value function's output directly or to take its index-0 element. As the docstring says, this mirrors jax.grad(), but the flag really refers to any extra output of the value function, not just an aux dictionary. The problem arises where this function is called in the Optimizer constructor (optimizer.py line 358): only value_func_has_aux is passed, so value_func_has_state does not trigger the index-0 behaviour, although I think it should.

I've fixed this locally by changing that call to use has_aux=value_func_has_aux or value_func_has_state. I haven't come across any other problems with state outputs, but some may still exist. This is the patch file I use for the 0.0.3 release:
kfac_jax.txt - I can submit it as a PR if that's helpful.
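The proposed change can be illustrated outside kfac-jax. This minimal sketch assumes the value-and-grad function follows the jax.value_and_grad(..., has_aux=True) convention, so its value is a (loss, extra) tuple whenever the function returns aux data or state. make_value_func and loss_with_state are hypothetical names. They are not kfac-jax's actual convert_value_and_grad_to_value_func.

```python
import jax
import jax.numpy as jnp


def make_value_func(value_and_grad_func, has_aux=False, has_state=False):
    """Hypothetical helper: keep only the loss when extra outputs exist."""
    def value_func(*args, **kwargs):
        value, _ = value_and_grad_func(*args, **kwargs)
        # Any extra output (aux dict or model state) means value is (loss, extra).
        return value[0] if (has_aux or has_state) else value
    return value_func


def loss_with_state(params, state, batch):
    # Hypothetical stateful model: returns (loss, new_state).
    x, y = batch
    pred = x @ params
    new_state = {"step": state["step"] + 1}
    return jnp.mean((pred - y) ** 2), new_state


vg = jax.value_and_grad(loss_with_state, has_aux=True)
value_func = make_value_func(vg, has_state=True)

params = jnp.array([1.0, -2.0])
state = {"step": jnp.array(0)}
batch = (jnp.array([[1.0, 0.0], [0.0, 1.0]]), jnp.array([1.0, 1.0]))
loss = value_func(params, state, batch)
```

Without the has_state check, value_func would return the whole (loss, new_state) tuple instead of a scalar. That matches the crash described above.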

Add Support for KFAC Optimization in LSTM and GRU Layers

Feature

I kindly request the addition of support for the Kronecker-Factored Approximate Curvature (KFAC) optimization technique in LSTM and GRU layers within the existing KFAC Optimizer. Currently, most of the KFAC Optimizer classes are tailored for linear and 2D convolution layers. Extending its capabilities to encompass RNN layers would be a significant enhancement.

Proposal

The proposal is to integrate KFAC support for LSTM and GRU layers into the KFAC optimizer. This would involve adapting the KFAC Optimizer to compute the required statistics, including the chain-structured linear Gaussian graphical model for LSTM and GRU layers, for which I could not find any public implementation.

Motivation

LSTM and GRU layers are foundational components for sequential data and time-series analysis. I wonder how much KFAC could improve the training of models with LSTM and GRU layers by providing accurate approximations of the Fisher information matrix. Supporting these layers in the KFAC Optimizer would let researchers apply KFAC to a wider range of models, including reinforcement learning algorithms.

Additional Context

I have full confidence that the repository maintainers, particularly the first author of the paper titled

I appreciate your consideration of this feature request. Thank you.

Unpack Error when using KFAC with block-diagonal for Dense networks

Hi,

I was trying to get the example code in the readme working with the BlockDiagonal approximation. The default simply uses the normal diagonal. However, when I try to define my optimizer like this:

opt = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(partial(expected_model_likelihood, l2=0.001)),
    l2_reg=0.001,
    use_adaptive_learning_rate=True,
    use_adaptive_damping=True,
    use_adaptive_momentum=True,
    initial_damping=1.0,
    min_damping= 0.0001,
    layer_tag_to_block_ctor={'generic_tag': kfac_jax.DenseTwoKroneckerFactored},  # Specify the approximation type here
    estimation_mode='ggn_curvature_prop',
    multi_device=False
)

then when I try to use this optimizer I get the following ValueError:

del pmap_axis_name
x, = estimation_data["inputs"]
dy, = estimation_data["outputs_tangent"]
assert utils.first_dim_is_size(batch_size, x, dy)

ValueError: not enough values to unpack (expected 1, got 0)

This corresponds to the curvature update method in class DenseTwoKroneckerFactored (line 1165) of _src.curvature_blocks.py. The estimation data dictionary is filled with the parameters and parameter tangents, but I do not understand the codebase well enough to see why the inputs and outputs_tangent keys are not filled.

As a result, I cannot get the actual KFAC of this repo working. Are there perhaps some examples that use DenseTwoKroneckerFactored? As far as I can tell, all provided examples use the diagonal Fisher for optimization, not KFAC, but I may be wrong of course.
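The traceback above shows that DenseTwoKroneckerFactored reads estimation_data["inputs"] and estimation_data["outputs_tangent"]. These are the two ingredients of a Kronecker-factored dense block. For a dense layer y = x @ W + b, K-FAC approximates the curvature of vec(W) by kron(A, G), where A = E[x x^T] is taken over layer inputs and G = E[dy dy^T] over output tangents. Its inverse is applied as A^-1 dW G^-1. The simplified sketch below uses plain JAX. The helper names are hypothetical, the bias is ignored, and damping is split as sqrt(damping) per factor, which simplifies the pi-weighted factored damping used in K-FAC. It suggests why this block cannot be estimated when, as observed above, the estimation data carries only parameters and parameter tangents.

```python
import jax
import jax.numpy as jnp


def kfac_dense_factors(x, dy):
    """Kronecker factors for y = x @ w (bias ignored).

    x:  [batch, in]  layer inputs     ("inputs")
    dy: [batch, out] output tangents  ("outputs_tangent")
    """
    n = x.shape[0]
    a = x.T @ x / n    # A = E[x x^T], shape [in, in]
    g = dy.T @ dy / n  # G = E[dy dy^T], shape [out, out]
    return a, g


def kfac_precondition(grad_w, a, g, damping):
    """Apply kron(A_d, G_d)^-1 to grad_w of shape [in, out].

    For a row-major flattening, kron(A, G) @ vec(X) == vec(A @ X @ G.T),
    so the inverse is A_d^-1 @ grad_w @ G_d^-1 (A_d, G_d symmetric).
    """
    sqrt_damping = jnp.sqrt(damping)
    a_d = a + sqrt_damping * jnp.eye(a.shape[0])
    g_d = g + sqrt_damping * jnp.eye(g.shape[0])
    left = jnp.linalg.solve(a_d, grad_w)     # A_d^-1 grad_w
    return jnp.linalg.solve(g_d, left.T).T   # (A_d^-1 grad_w) G_d^-1


key_x, key_dy = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (32, 4))
dy = jax.random.normal(key_dy, (32, 3))
grad_w = x.T @ dy / x.shape[0]  # batch-averaged weight gradient
a, g = kfac_dense_factors(x, dy)
update = kfac_precondition(grad_w, a, g, damping=0.1)
```

The factored form only inverts an [in, in] and an [out, out] matrix. The full kron(A, G) would be [in*out, in*out], which is why the block needs the per-layer inputs and output tangents rather than just the parameters.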

Quick question on "layer_tag_vjp"

Hey KFAC team,

First of all, thanks a lot for this awesome project and all the hard work!

Got a quick question on the implementation of _layer_tag_vjp function in "tracer.py". The version is 0.0.3.

For the returned function vjp_func, my understanding is that it reads primal and tangent value of the "layer inputs" from previously constructed information, specifically primals_dict and tangents_dict. My questions are:

  1. For "primals_dict": it contains info not only for all the layer inputs but also for the inputs of the whole jaxpr. See here. It seems to me that the latter is not needed, since vjp_func only retrieves info for layer inputs. So is the info about the inputs of the whole jaxpr really necessary here, and if so, why?
  2. For "tangents_dict": it is constructed from aux_vjp. The implementation reads:
    all_tangents = aux_vjp(tangents)
    tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]
    inputs_tangents = jax.tree_util.tree_leaves(inputs_tangents)
    tangents_dict.update(zip(processed_jaxpr.jaxpr.invars, inputs_tangents))

Here aux_vjp is the vjp function for forward_aux (See here). Therefore the output of aux_vjp, namely all_tangents, should have the same structure as the input of forward_aux. Since forward_aux takes only a single argument, all_tangents should be a tuple with a single element. If that's the case, inputs_tangents is always empty, and the implementation can be simplified to

tangents_dict, = aux_vjp(tangents)

Am I missing anything here? Or in which case will we have a non-empty inputs_tangents?
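The question rests on JAX's vjp contract. The function returned by jax.vjp(f, *primals) yields one cotangent per positional primal, as a tuple. The small check below uses a hypothetical single-argument forward_aux, not kfac-jax's, and confirms that slicing [1:] then gives an empty tuple. It does not settle whether kfac-jax's real forward_aux ever receives extra positional arguments.

```python
import jax
import jax.numpy as jnp


def forward_aux(aux):
    # Hypothetical single-argument function standing in for forward_aux.
    return jnp.sum(jnp.sin(aux["a"]) * aux["b"])


def forward_two(aux, extra):
    # Same computation with a second positional input, for comparison.
    return forward_aux(aux) * extra


aux = {"a": jnp.arange(3.0), "b": jnp.ones(3)}

out, vjp_fn = jax.vjp(forward_aux, aux)
all_tangents = vjp_fn(jnp.ones_like(out))
tangents_dict, inputs_tangents = all_tangents[0], all_tangents[1:]

out2, vjp_fn2 = jax.vjp(forward_two, aux, 2.0)
all_tangents2 = vjp_fn2(jnp.ones_like(out2))
```

With one primal, all_tangents has length 1 and inputs_tangents is empty. A second positional primal is the only way the slice becomes non-empty.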

KFAC Norm Constraint

Hi,

In the documentation of applying a norm constraint to the update gradient, it says:

norm_constraint: Scalar. If specified, the update is scaled down so that
        its approximate squared Fisher norm ``v^T F v`` is at most the specified
        value. (Note that here ``F`` is the approximate curvature matrix, not
        the exact.)

and the corresponding part of the code:

preconditioned_grads = self.estimator.multiply_inverse(
    state=state.estimator_state,
    parameter_structured_vector=grads,
    identity_weight=self.l2_reg + damping,
    exact_power=self._use_exact_inverses,
    use_cached=self._use_cached_inverses,
    pmap_axis_name=self.pmap_axis_name,
)

if self._norm_constraint is not None:

  assert not self._use_adaptive_learning_rate
  assert coefficient is not None

  sq_norm_grads = utils.inner_product(preconditioned_grads, grads)

  sq_norm_scaled_grads = sq_norm_grads * coefficient ** 2

  max_coefficient = jnp.sqrt(self._norm_constraint / sq_norm_scaled_grads)
  coefficient = jnp.minimum(max_coefficient, 1)

  preconditioned_grads = utils.scalar_mul(preconditioned_grads, coefficient)

However, as far as I know, preconditioned_grads is F^-1 v, so sq_norm_grads actually computes v^T F^-1 v rather than v^T F v as documented. Have I understood this correctly, or is this intended?
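One algebraic step may resolve the apparent mismatch. Write g for the gradient (called v in the question). Assume the docstring's F is the same damped matrix that multiply_inverse inverts, that is, the curvature plus (l2_reg + damping) * I, and not a stale cached inverse. The unscaled update is then u = F^-1 g, and u^T F u = g^T F^-1 F F^-1 g = g^T F^-1 g. That is exactly what utils.inner_product(preconditioned_grads, grads) computes. Scaling the update by coefficient multiplies this value by coefficient ** 2, which matches sq_norm_scaled_grads. The numerical check below uses a random symmetric positive-definite matrix as a stand-in for F:

```python
import jax
import jax.numpy as jnp

key_f, key_g = jax.random.split(jax.random.PRNGKey(0))
m = jax.random.normal(key_f, (5, 5))
f = m @ m.T + 0.5 * jnp.eye(5)      # SPD stand-in for the damped curvature F
g = jax.random.normal(key_g, (5,))  # stand-in for the gradient

u = jnp.linalg.solve(f, g)       # preconditioned_grads = F^-1 g
sq_norm_code = jnp.dot(u, g)     # what the code computes: g^T F^-1 g
sq_norm_doc = jnp.dot(u, f @ u)  # documented quantity: u^T F u

coefficient = 0.3
scaled = coefficient * u
sq_norm_scaled_doc = jnp.dot(scaled, f @ scaled)
```

In this reading, the "v" in the docstring is the update, not the raw gradient. The two formulas then coincide.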

Using kfac inside jitted function

Dear All,

In the problem I am working on, I need to use Optimizer.step inside a larger function that is jitted for performance reasons. Is there a canonical way to do this without dramatically hurting performance or triggering errors?

Best,
Honza

How to use kfac to train two probabilistic models jointly?

In my application, I need to jointly optimize two probabilistic models, which contribute two different terms to the final loss function.

I am wondering what the recommended pattern for using kfac is.
More specifically, does it make sense to call kfac_jax.register_normal_predictive_distribution twice (once for each probabilistic model)?

Thanks in advance!

Can this be used for Laplace approximation?

In a Laplace approximation, the Hessian of the loss function is used to build a quadratic approximation. Can this package compute a block-diagonal approximation of the Hessian at the minimum? If so, could you please show (using jax and flax) how to compute it and how to define a quadratic approximation of the loss function? This should be something like 1/2 (theta - theta_star)^T H(L)(theta_star) (theta - theta_star), where theta_star is the minimum and H(L) is the Hessian of the loss function.
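The quadratic model itself can be built in plain JAX, independently of kfac-jax. kfac-jax's curvature blocks approximate Fisher/Gauss-Newton-type matrices rather than the exact Hessian. This minimal sketch assumes theta_star is an exact minimum, so the gradient term vanishes, and that the model is small enough for a dense Hessian. Flax parameters are a pytree, so ravel_pytree applies to them unchanged. laplace_quadratic is a hypothetical helper. block_diagonal=True keeps one block per parameter array, which is only one possible choice of blocks.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def laplace_quadratic(loss_fn, params_star, block_diagonal=False):
    """Return q(theta) = L(theta*) + 0.5 d^T H d with d = theta - theta*."""
    flat_star, unravel = ravel_pytree(params_star)
    hess = jax.hessian(lambda p: loss_fn(unravel(p)))(flat_star)
    if block_diagonal:
        # Block id for each flat coordinate: one block per pytree leaf,
        # in the same leaf order that ravel_pytree uses.
        sizes = [leaf.size for leaf in jax.tree_util.tree_leaves(params_star)]
        ids = jnp.concatenate([jnp.full((s,), i) for i, s in enumerate(sizes)])
        hess = jnp.where(ids[:, None] == ids[None, :], hess, 0.0)
    loss_star = loss_fn(params_star)

    def quadratic(params):
        d = ravel_pytree(params)[0] - flat_star
        return loss_star + 0.5 * d @ hess @ d

    return quadratic, hess


def loss(params):
    # Hypothetical quadratic loss with its minimum at w = 1, b = 3
    # and a small coupling between w and b.
    dw = params["w"] - 1.0
    db = params["b"] - 3.0
    return 0.5 * jnp.sum(dw ** 2) + 2.0 * db ** 2 + 0.1 * jnp.sum(dw) * db


params_star = {"w": jnp.ones(3), "b": jnp.array(3.0)}
q_full, h_full = laplace_quadratic(loss, params_star)
q_block, h_block = laplace_quadratic(loss, params_star, block_diagonal=True)
theta = {"w": jnp.array([0.5, 2.0, 1.5]), "b": jnp.array(2.0)}
```

Because this loss is exactly quadratic, q_full(theta) reproduces loss(theta) (2.65 here). q_block drops the w-b coupling and gives 2.75. For a real network, the full Hessian is usually too large, and a structured approximation replaces hess.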
