flax's Introduction

Flax: A neural network library and ecosystem for JAX designed for flexibility

Overview | Quick install | What does Flax look like? | Documentation

📣 NEW: Check out the NNX API!

This README is a very short intro. To learn everything you need to know about Flax, refer to our full documentation.

Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.

Flax is being used by a growing community of hundreds of folks in various Alphabet research departments for their daily work, as well as a growing community of open source projects.

The Flax team's mission is to serve the growing JAX neural network research ecosystem -- both within Alphabet and with the broader community, and to explore the use-cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads. We are in the process of moving some remaining internal design docs and conversation threads to GitHub discussions, issues and pull requests. We hope to increasingly engage with the needs and clarifications of the broader ecosystem. Please let us know how we can help!

Please report any feature requests, issues, questions or concerns in our discussion forum, or just let us know what you're working on!

We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at [email protected].

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

  • Neural network API (flax.linen): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

  • Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device

  • Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

  • Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b

Quick install

You will need Python 3.6 or later, and a working JAX installation (with or without GPU support - refer to the instructions). For a CPU-only version of JAX:

pip install --upgrade pip # To support manylinux2010 wheels.
pip install --upgrade jax jaxlib # CPU-only

Then, install Flax from PyPi:

pip install flax

To upgrade to the latest version of Flax, you can use:

pip install --upgrade git+https://github.com/google/flax.git

To install some additional dependencies (like matplotlib) that are required by some dependencies but not installed by them, you can use:

pip install "flax[all]"

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.

To learn more about the Module abstraction, check out our docs, in particular our broad intro to the Module abstraction. For additional concrete demonstrations of best practices, refer to our guides and developer notes.

from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)

class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    input_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(self.decoder_widths + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

model = AutoEncoder(encoder_widths=(20, 10, 5),
                    decoder_widths=(5, 10, 20),
                    input_shape=(12,))
batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)

🤗 Hugging Face

Detailed examples for training and evaluating a variety of Flax models for Natural Language Processing, Computer Vision, and Speech Recognition are actively maintained in the 🤗 Transformers repository.

As of October 2021, the 19 most-used Transformer architectures are supported in Flax and over 5000 pretrained checkpoints in Flax have been uploaded to the 🤗 Hub.

Citing Flax

To cite this repository:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.9.0},
  year = {2024},
}

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

flax's People

Contributors

8bitmp3, adarob, alexeyg, ameya98, andsteing, avital, bastings, bohnetbd, cgarciae, chiamp, cpgaffney1, danielsuo, dependabot[bot], gmittal, hawkinsp, ivyzx, jheek, joaogui1, levskaya, makora9143, marcvanzee, marvin182, mbz, melissatan, mohitreddy1996, neilgirdhar, romanngg, wenscarl, wrzadkow, yashk2810

flax's Issues

Polyak Averaging Params_ema initialization

Hi, I was applying the changes in your HOWTO to add Polyak averaging and there seems to be some code missing, specifically params_ema is not initialized and so the line
optimizer, params_ema = train_step(optimizer, params_ema, batch)
causes an UnboundLocalError (params_ema referenced before assignment).
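One way to avoid the error is to initialize the average from the starting parameters before the first training step. The sketch below uses plain JAX pytrees and is not the HOWTO's code: ema_update, the decay value and the stand-in parameter dictionary are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def ema_update(params_ema, params, decay=0.999):
  """Moves each averaged leaf a small step toward the current parameter."""
  return jax.tree_util.tree_map(
      lambda e, p: decay * e + (1.0 - decay) * p, params_ema, params)


# Stand-in parameters; in real code these would come from the model.
params = {'kernel': jnp.ones((2, 2)), 'bias': jnp.zeros((2,))}

# Initialize the average *before* the first training step.
params_ema = params

for step in range(3):
  # Placeholder for a real train step that returns updated parameters.
  params = jax.tree_util.tree_map(lambda p: p + 1.0, params)
  params_ema = ema_update(params_ema, params)
```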

TypeError: iteration over a 0-d array

When I run the following code, I get a TypeError: iteration over a 0-d array error from Jax. The code looks correct to me, and I don't understand where this error is coming from or how to fix the issue.

import jax
import jax.numpy as jnp
from flax import nn, optim

class CNN(nn.Module):
    def apply(self, x):
        x = nn.Conv(x, features=32, kernel_size=(3, 3))
        x = x.reshape((x.shape[0], -1))
        v = nn.Dense(x, features=1)
        return v

@jax.jit
def train_step(optimizer, observations, returns):
    def loss_fn(model):
        values_pred = model(observations)
        return jnp.square(values_pred - returns).mean()
    optimizer, _ = optimizer.optimize(loss_fn)
    return optimizer

batch_size = 7
input_shape = (batch_size, 32, 32, 5)

key = jax.random.PRNGKey(0)

_, model = CNN.create_by_shape(key, [(input_shape, jnp.float32)])
optimizer = optim.Adam(learning_rate=0.01).create(model)

observations = jax.random.normal(key, input_shape)
returns = jax.random.normal(key, (batch_size, ))

optimizer = train_step(optimizer, observations, returns)

produces the following stack trace

Traceback (most recent call last):
  File "test.py", line 33, in <module>
    optimizer = train_step(optimizer, observations, returns)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 150, in f_jitted
    name=flat_fun.__name__)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py", line 605, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 449, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 223, in memoized_fun
    ans = call(fun, *args)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 466, in _xla_callable
    jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 152, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "test.py", line 18, in train_step
    optimizer, l, logits = optimizer.optimize(loss_fn)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/optim.py", line 266, in optimize
    loss, aux, grad = self.compute_gradients(loss_fn)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/optim.py", line 250, in compute_gradients
    (loss, aux), grad = grad_fn(self.target)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 413, in value_and_grad_f
    ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 1293, in vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/ad.py", line 111, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/ad.py", line 98, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 337, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 156, in call_wrapped
    ans = gen.send(ans)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api_util.py", line 72, in flatten_fun_nokwargs2
    ans, aux = yield py_args, {}
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py", line 300, in __iter__
    return iter(self.aval._iter(self))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/lax/lax.py", line 1497, in _iter
    raise TypeError("iteration over a 0-d array")  # same as numpy error
TypeError: iteration over a 0-d array

Any help is greatly appreciated!
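The traceback shows value_and_grad being called with has_aux=True inside compute_gradients, which suggests the loss function is expected to return a (loss, aux) pair rather than a bare scalar. The sketch below shows that convention in plain JAX. It assumes this is the cause of the error and does not use the flax.optim API from the report; the linear model is a stand-in.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
  pred = x @ w
  loss = jnp.mean((pred - y) ** 2)
  # With has_aux=True the function must return a (loss, aux) pair.
  return loss, pred


w = jnp.ones((3,))
x = jnp.ones((7, 3))
y = jnp.zeros((7,))

# The result is unpacked as ((loss, aux), grad).
(loss, pred), grad = jax.value_and_grad(loss_fn, has_aux=True)(w, x, y)
```

In the reported code, the equivalent change would be to return both the loss and the predictions from loss_fn, matching the three-value unpacking that appears in the traceback.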

Reorganizing optim into directory structure?

Hello again! At the Princeton office, we work on, among other things, optimization algorithms for deep learning. We're interested in using flax and wanted to add some other well-known algorithms. Would you guys be open to reorganizing optim.py into a directory a la pytorch? Happy to submit a PR if so!

Usually, this would accompany a PR, but being new around here, wanted to understand how (if at all) you wanted to reorganize.

One possibility: All subclasses of OptimizerDef (except MultiOptimizer, which appears to have a circular dependency with OptimizerDef) live in their own files (e.g., Momentum, GradientDescent)

Cannot run seq2seq train example :(

Hi!
First, thank you very much for this project!!!
Second, I am just simply trying to run this script to learn how to use GRU/LSTM: https://github.com/google/flax/blob/prerelease/examples/seq2seq/train.py

and I have encountered 2 issues:

  1. batch_metrics is missing: this is not an important issue (I just commented it out), but maybe you would like to do something about it :)

  2. An error that is beyond my skills:

Traceback (most recent call last):
  File "/.../FlaxExample.py", line 288, in <module>
    app.run(main)
  File "/.../anaconda3/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/.../anaconda3/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/.../FlaxExample.py", line 284, in main
    _ = train_model()
  File "/.../FlaxExample.py", line 271, in train_model
    model = create_model()
  File "/.../FlaxExample.py", line 174, in create_model
    ((1, get_max_output_len(), vocab_size), jnp.float32)])
  File "/.../anaconda3/lib/python3.7/site-packages/flax/nn/base.py", line 183, in wrapper
    return super_fn(*args, **kwargs)
  File "/.../anaconda3/lib/python3.7/site-packages/flax/nn/base.py", line 392, in init_by_shape
    return jax_utils.partial_eval_by_shape(lazy_init, input_specs)
  File "/.../anaconda3/lib/python3.7/site-packages/flax/jax_utils.py", line 94, in partial_eval_by_shape
    output_shapes = jax.eval_shape(lazy_fn, *input_structs)
  File "/.../anaconda3/lib/python3.7/site-packages/jax/api.py", line 2023, in eval_shape
    out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
  File "/.../anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 256, in abstract_eval_fun
    instantiate=True)
  File "/.../anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 337, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/.../anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 152, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/.../anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 152, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/.../anaconda3/lib/python3.7/site-packages/flax/jax_utils.py", line 88, in lazy_fn
    master = leaves[0]._trace.master  # pylint: disable=protected-access
  File "/.../anaconda3/lib/python3.7/site-packages/jax/core.py", line 365, in __getattr__
    attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute '_trace'

I believe I have the latest Jax and Flax versions:
Flax: 0.0.1a0
Jax : 0.1.58

Thanks for your help :)

Investigate and fix cause of warnings defined in pytest.ini

With pytest.ini defined, all the warnings captured during a test execution are reported as exceptions.

A few of the warnings are independent of Flax's implementation, however (e.g., the DeprecationWarning and UserWarning captured when importing gfile from tensorflow.compat.v2.io).

Ideally we should have ZERO entries as exemptions in pytest.ini. Investigate the root cause and fix them.
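The mechanism can be illustrated with Python's standard warnings module, which pytest's warning filters build on. This is a sketch only: legacy_call is a hypothetical stand-in for library code, and the filters here are set in code rather than in pytest.ini.

```python
import warnings


def legacy_call():
  """Stand-in for library code that emits a deprecation warning."""
  warnings.warn('legacy_call is deprecated', DeprecationWarning)
  return 42


# With an "error" filter, every warning is raised as an exception.
with warnings.catch_warnings():
  warnings.simplefilter('error')
  try:
    legacy_call()
    raised = False
  except DeprecationWarning:
    raised = True

# An exemption: ignore one category instead of failing on it.
with warnings.catch_warnings():
  warnings.simplefilter('error')
  warnings.filterwarnings('ignore', category=DeprecationWarning)
  result = legacy_call()
```

Each exemption hides a warning that would otherwise fail the test, which is why the goal is to remove the root cause rather than keep the exemption.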

Non fully reproducible results on GPU

Although the random key is fixed (e.g. jax.random.PRNGKey(0)), the results differ from run to run.

How can I make the behavior deterministic? My expectation is that with a fixed random key, all runs should produce the same result.

Thank you in advance.

Use the following code to reproduce the issue (I simply take the MNIST example with shuffle removed):

import jax
import flax
import numpy as onp
import jax.numpy as jnp
import tensorflow_datasets as tfds

class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = jax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=10)
    x = flax.nn.log_softmax(x)
    return x

@jax.vmap
def cross_entropy_loss(logits, label):
  return -logits[label]

def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}

@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch['image'])
    loss = jnp.mean(cross_entropy_loss(
        logits, batch['label']))
    return loss, logits
  optimizer, _, _ = optimizer.optimize(loss_fn)
  return optimizer

@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds['image'] / 255.0)
  return compute_metrics(logits, eval_ds['label'])

def train():
  train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.cache().batch(128)
  test_ds = tfds.as_numpy(tfds.load(
      'mnist', split=tfds.Split.TEST, batch_size=-1))

  _, model = CNN.create_by_shape(
      jax.random.PRNGKey(0),
      [((1, 28, 28, 1), jnp.float32)])

  optimizer = flax.optim.Momentum(
      learning_rate=0.1, beta=0.9).create(model)

  for epoch in range(10):
    for batch in tfds.as_numpy(train_ds):
      batch['image'] = batch['image'] / 255.0
      optimizer = train_step(optimizer, batch)

    metrics = eval(optimizer.target, test_ds)
    print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
         % (epoch+1,
          metrics['loss'], metrics['accuracy'] * 100))

train()
train()

Language Model Notebook Broken

This is just a tracking bug report - the TPU language model Colab notebook is currently broken due to a low-level change in jaxlib that's interacting badly with the Colab TPU backends. We should be able to fix it soon.

What am I missing?

uhmm...

432 stars....
featured on trending...
empty repository....
2 issues...

what am i missing?

Simplify the CIFAR10 example

Currently, the CIFAR10 example implements 2 architectures (Wide ResNet and PyramidNet), 2 regularization methods (Shake-shake and shake-drop), several learning rate schedules and allows training various combinations of these. This makes the example large and unnecessarily complex.

  • Can we just keep one or two of the above combinations?
  • Which ones are more relevant / useful / educational?
  • Should we turn some other combinations into HOW-TOs? Which ones?

Flattening parameters

Hi,

Great package! I like the syntax.

Is there an easy way to pack and unpack parameters for a flax.nn.Module? I saw this in the optimizer code:

    new_params = jax.tree_unflatten(treedef, new_params_flat)
    new_param_states = jax.tree_unflatten(treedef, new_states_flat)

as well as this:

  def init_state(self, params):
    sub_states = []
    for traversal, opt in zip(self.traversals, self.sub_optimizers):
      params_t = list(traversal.iterate(params))
      state = opt.init_state(params_t)
      sub_states.append(state)

but was wondering if there was a simpler way to convert the parameters to a single array (which I can still compute gradients on)? I'm interested in doing some operations on the parameters as a single vector as part of the loss function.

Thanks,
Miles
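One possible approach is jax.flatten_util.ravel_pytree, which flattens a pytree into a single 1-D array and returns a function that restores the original structure. The sketch below uses a stand-in parameter dictionary rather than a real flax.nn.Module, and loss_fn is a hypothetical example.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Stand-in parameters; a real model's parameter tree would be used instead.
params = {'dense': {'kernel': jnp.ones((3, 2)), 'bias': jnp.zeros((2,))}}

# flat is a 1-D array; unravel maps such an array back to the original tree.
flat, unravel = ravel_pytree(params)


def loss_fn(flat_params):
  p = unravel(flat_params)
  # An operation on the parameters as a single vector...
  l2 = jnp.sum(flat_params ** 2)
  # ...combined with a term that uses the structured parameters.
  out = jnp.ones((1, 3)) @ p['dense']['kernel'] + p['dense']['bias']
  return l2 + jnp.sum(out)


grad_flat = jax.grad(loss_fn)(flat)   # gradient as a single vector
grad_tree = unravel(grad_flat)        # same gradient, in tree form
```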

Example: DQN

I would like to add a DQN example, is there interest in it?
The idea is to use only flax and JAX, being independent of rlax

Improve documentation

Let's use this issue to collect all the places in which the documentation should be improved. If you think there's something missing from the list, please comment and we'll update the list.

As we work through this list, we'll update the description with links to PRs addressing each of the items.

  • nn.Collection
  • Module.apply
  • Module.call
  • Module should get a docstring
  • nn.stateful
  • nn.stochastic
  • Module.get_param (explain when it should be used)

A flax module without trainable parameter changes the rng

In a flax module, I turned a function that applies a bunch of numpy operations into a flax (sub-)module that has no trainable parameter, e.g.:

def add(x, y):
  return x + y

to

@nn.module
def add(x, y):
  return x + y

I noticed that I get different numbers at the end. Seems this is because the submodule changes the random number generators by splitting the rng of the parent module, while this behaviour is not expected in this case.
Maybe it makes more sense to condition splitting the rng on the existence of trainable parameters in the submodule?
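The effect described above can be reproduced with a toy model in plain JAX. This is a sketch of the suspected mechanism, not Flax's actual key handling: it assumes the parent splits its key once per submodule, and sample_after is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def sample_after(num_extra_submodules):
  """Toy model of a parent that splits its key once per submodule."""
  key = jax.random.PRNGKey(0)
  for _ in range(num_extra_submodules):
    # Each submodule receives a sub-key, even if it never uses it.
    key, _unused = jax.random.split(key)
  key, sub = jax.random.split(key)
  return jax.random.normal(sub, (4,))


without_module = sample_after(0)
with_module = sample_after(1)
same_seed_again = sample_after(0)
```

Under this assumption, the extra split shifts which key later consumers receive, so with_module differs from without_module even though the seed is identical.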

FLIP: rm 'shape' from Module.param call signature

Goals

In our current implementation, initializer methods (which get passed to Module.param) are required to have the following signature:

def initializer(rng_key, shape):
  # ...
  return initialized_parameters

and in Module.param we assert that initialized_parameters.shape == shape.

Sometimes an initializer needs more (or less) information than the shape of its output, and at the moment this is achieved by writing the initializer function within a Module definition so that it can close over other data that it requires. For example, weightnorm initialization, which is data-dependent, can be implemented as follows. Note the necessity of the dummy shape arguments:

class Conv2D(nn.Module):
  def apply(self, inputs, features, kernel_size):
    strides = (inputs.ndim - 2) * (1,)

    conv = partial(
        lax.conv_general_dilated, window_strides=strides, padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

    in_features = inputs.shape[-1]
    kernel_shape = kernel_size + (in_features, features)

    def initializer(key, shape):
      # A weightnorm initializer generating a (direction, scale, bias) tuple.
      # Note that the shape argument is not used.
      direction = nn.initializers.normal()(key, kernel_shape)
      unnormed_out = conv(inputs, _l2_normalize(direction))
      mean = np.mean(unnormed_out, (0, 1, 2))
      std = np.std(unnormed_out, (0, 1, 2))
      return dict(direction=direction, scale=1 / std, bias=-mean / std)

    # We feed in None as a dummy shape argument to self.param.  Currently
    # Module.param assumes that the initializer takes in a shape argument;
    # None acts as a flag to avoid shape checking.
    params = self.param('weightnorm_params', None, initializer)
    direction, scale, bias = [params[k] for k in ('direction', 'scale', 'bias')]
    return conv(inputs, _make_kernel(direction, scale)) + bias

This situation isn't terrible, but it does highlight the fact that the assumption that initializers depend on parameter shape and nothing else is a bit arbitrary.

A more flexible API, with initializer requiring only a JAX PRNG key, would mean more consistent implementations of different types of initializers, and might also help to correctly emphasize what Flax's role is here, namely to handle splitting and passing of PRNG keys to initializers and to setup the parameters data-structure (a nested dictionary).

Proposal

We propose to change the call signature of Module.param from

param(self, name: str, shape: Shape, initializer: Callable[[Key, Shape], Array]):

to

param(self, name: str, initializer: Callable[[Key], Array]):

This change would lead to a slight simplification of the weightnorm example above, since the dummy shape arguments could be removed.

For existing Modules for which the initializer is a straightforward function of the parameter shape and no other data is required, we can alter the currying of the initializer definitions so that lines like

kernel = self.param('kernel', kernel_shape, kernel_init)

can be replaced by

kernel = self.param('kernel', kernel_init(kernel_shape))
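The curried form can be sketched in plain JAX as follows. normal_init and bind_shape are hypothetical names, not the existing nn.initializers API; the point is only that the function finally handed to param takes a PRNG key and nothing else.

```python
import jax
import jax.numpy as jnp


def normal_init(stddev=1e-2):
  """Hypothetical curried initializer: configure, then bind shape, then key."""
  def bind_shape(shape, dtype=jnp.float32):
    def init(key):
      # Only the PRNG key is needed here; the shape is already closed over.
      return stddev * jax.random.normal(key, shape, dtype)
    return init
  return bind_shape


kernel_init = normal_init(stddev=0.1)
kernel_shape = (3, 4)
init_fn = kernel_init(kernel_shape)      # a Callable[[Key], Array]
kernel = init_fn(jax.random.PRNGKey(0))
```

A data-dependent initializer would have the same final signature, closing over its inputs instead of a shape.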

There may be downsides to this approach which I, being a relative Flax noob, am unaware of. One thing is that we'd lose the shape checking in Module.param, but that seems like the kind of check which should be part of a test anyway.

Alternatives

I think the obvious alternative is to simply keep the current API. The change proposed above is relatively minor but it still would likely require a number of users to make changes to their own code.

FLIP: Make module instances semantically meaningful by not overriding `module.__new__`

Introduction

Currently, while Flax modules are defined by subclassing flax.nn.Module, those modules don't behave the same way that normal Python objects behave.

One of the large differences is that Flax Modules override __new__, meaning that module instances aren't a semantically meaningful thing in Flax at the moment. Right now, in Flax, what looks like module construction (nn.Dense(x, features=10)) actually does two things:

  1. Construct an object of type nn.Dense (using the non-documented API module.new_instance())
  2. Call the apply method on that instance and return it.
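The two steps above can be mimicked with a toy base class in plain Python. This is an illustration of the general mechanism, not Flax's implementation: Module and Scale here are hypothetical classes.

```python
class Module:
  """Toy base class (not Flax's implementation)."""

  def __new__(cls, *args, **kwargs):
    instance = object.__new__(cls)
    # Returning something that is not an instance of cls skips __init__,
    # so "constructing" the module actually yields the output of apply.
    return instance.apply(*args, **kwargs)


class Scale(Module):
  def apply(self, x, factor):
    return x * factor


# Looks like construction, but y is the result of apply, not a Scale object.
y = Scale(3, factor=2)
```

The caller never gets a handle on the Scale instance itself, which is why module instances carry no meaning in this style.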

Some upsides of the current approach are:

  1. Modules are defined as a single function, as opposed to, e.g. the style of other libraries, such as Haiku, where you need to scroll up and down between __init__ and __call__ to fully understand what a module does.
  2. Calls to submodules are very concise, e.g. nn.Dense(x, features=10).

Some downsides of the current approach are:

  1. In order to reuse a module, you must use the module.shared() abstraction which has a confusing mental model -- what does module.shared() return? A module class? A module instance? Moreover, which arguments must be passed into module.shared() in order for the shared module to be usable? (Behind the scenes shared is implemented on top of partial)
  2. You can't instantiate a module directly, outside of another module. This leads to surprising things like new nn.Model(nn.Dense.partial(features=10), params) -- why do we need to use partial to instantiate a Model? What type does the first argument to nn.Model have? Is it a module class? Module instance?
  3. In a few spots in flax/nn/base.py there is code that does "kwarg mangling". Some of this code has had bugs before. It would be nice to reduce the need for kwarg mangling.
  4. In order to support multiple methods on a module, the module_method decorator turns methods that aren't apply into new Modules. This is surprising, for example how would I do the equivalent of module.call(params, *args) but to call a method foo that's not apply? That would be module.foo.call(params, *args). That's a pretty surprising mental model.
  5. Wanting to define shared parameters or submodules that work across multiple methods requires either using non-traditional patterns and/or with more complexity in Flax core (see discussion on #161)
  6. apply is a special-cased method on modules.

Proposal

  1. No longer override __new__ in Modules
  2. Eliminate .partial()
  3. Potentially eliminate .shared() (though we may choose to keep it as a safeguard -- see below)
  4. Split up current module's apply methods into the controlled use of Python 3.7 dataclasses (for storing module hyperparameters) and a "vanilla Python" __call__ method (or actually, any name you want) that only takes in the module input(s)
  5. This may even allow for module instance to directly refer to a read-only version of their parameters, e.g.:
class Foo(Module):
  def __call__(self, x):
    dense = nn.Dense(features=16)
    x = dense(x)
    # `dense.params` is defined here; maybe also `dense.params.kernel` and `dense.params.bias`
    return x

For example, a simple Dense layer may look like this:

@dataclass
class Dense(Module):
  features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros

  def __call__(self, x):
    """Applies a linear transformation to the inputs along the last dimension."""
    kernel = self.param('kernel', (x.shape[-1], self.features), self.kernel_init)
    bias = self.param('bias', (self.features,), self.bias_init)
    return jnp.dot(x, kernel) + bias

Then, an MLP would look like this:

class MLP(Module):
  def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)
    return x

I believe that this proposal keeps the conciseness of current Flax, while having the potential to significantly reduce both implementation complexity and mental model complexity. The mental model in Flax now reduces to the same one as Keras (other than the fact that parameters are immutable).

For example, in this case re-using a module is trivial -- keep a reference to nn.Dense(features=16) and re-use that. (NOTE: We may choose to keep the safe-guarding behavior of .shared() that makes it hard to accidentally copy and paste code that accidentally re-uses modules. We can achieve that by having modules default to raising an error when __call__ is invoked a second time, unless .shared() was called on the module instance first)
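The safeguarding behavior described in the note could look roughly like the following plain-Python sketch. It is one possible reading of the idea, not a proposed implementation: the call counter, the forward method name and the Scale module are assumptions made for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class Module:
  """Toy module base with a reuse safeguard (hypothetical)."""
  _shared: bool = field(default=False, init=False, repr=False)
  _calls: int = field(default=0, init=False, repr=False)

  def shared(self):
    # Opt in to calling this instance more than once.
    self._shared = True
    return self

  def __call__(self, *args, **kwargs):
    if self._calls >= 1 and not self._shared:
      raise RuntimeError('Module called twice; call .shared() to reuse it.')
    self._calls += 1
    return self.forward(*args, **kwargs)


@dataclass
class Scale(Module):
  factor: float = 1.0

  def forward(self, x):
    return x * self.factor


layer = Scale(factor=2.0)
first = layer(3.0)
try:
  layer(3.0)            # second call without .shared()
  reuse_failed = False
except RuntimeError:
  reuse_failed = True

shared_layer = Scale(factor=2.0).shared()
outputs = [shared_layer(1.0), shared_layer(2.0)]
```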

With this proposal, there's also no need for module.partial -- you can just use functools.partial(module.__call__) or functools.partial(module). (Though this is a bit different than in current Flax because the return value of functools.partial in itself isn't a module, rather it's a function. But maybe it was always confusing to understand module.partial -- does it override kwargs for all module methods? Just apply?)

Possible transition plan

Given the non-trivial amount of code written using Flax, and the fact that this proposal would change every module written with Flax, we need an upgrade plan.

I propose adding, alongside every new module in flax.nn, a function with the same name but lower-cased, that operates the same as in current Flax. These functions would be deprecated-on-arrival. E.g., alongside Dense as shown above we would also have

def dense(x, features, kernel_init, bias_init):
  """DEPRECATED. Use the new Module API: http://link/to/upgrade/guide."""
  return Dense(features, kernel_init, bias_init)(x)

Then the first part of the upgrade process is mainly search and replace "Dense" -> "dense", etc. In addition, some more manual changes will possibly be necessary for uses of .partial and .shared. Later, users can transition to the new API at a time they see fit.
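A deprecated-on-arrival shim could also emit a runtime warning in addition to the docstring note. The sketch below assumes a toy `Dense` stand-in (it is not the real layer, and its computation is a placeholder) and uses the standard `warnings` module.

```python
import warnings


class Dense:
    """Stand-in for the proposed class-style module (illustrative only)."""

    def __init__(self, features):
        self.features = features

    def __call__(self, x):
        # Placeholder computation: repeat the input `features` times.
        return [x] * self.features


def dense(x, features):
    """DEPRECATED lower-cased shim that forwards to the new-style module."""
    warnings.warn("dense() is deprecated; use Dense(...)(x) instead",
                  DeprecationWarning, stacklevel=2)
    return Dense(features)(x)


# Record warnings so the deprecation can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = dense(1.0, features=3)
```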

Current State

@avital has a messy work-in-progress branch checking the viability of using dataclasses in this setting. Results so far seem cautiously promising, but more work is needed before this proposal is ready to be acted on.

Full end-to-end MNIST example gives error.

I copied and pasted the full end-to-end MNIST example code, ran train(), and got an error.

Colab environment packages versions:

  • TensorFlow: '2.1.0'
  • Jax: '0.1.58'
  • Numpy: '1.17.5'

Error:

AttributeError                            Traceback (most recent call last)
<ipython-input-4-2da0ffaf5447> in <module>()
----> 1 train()

12 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in __getattr__(self, name)
    363 
    364     try:
--> 365       attr = getattr(self.aval, name)
    366     except KeyError:
    367       raise AttributeError(

AttributeError: 'ShapedArray' object has no attribute '_trace'

Example: pix2pix

I want to add a pix2pix example. We can discuss any questions, issues, etc. here. I already have some pix2pix work in a draft PR.

Related issue: #192
Related pr: #186

FLIP: Support __init__ in Modules

Introduction

By default Modules are defined using only an apply function and unshared parameters and submodules. This makes it easy to write modules (with control flow) in a concise and readable way.

However, some modules don't fit well in this abstraction. For example consider an autoencoder. During model training we would like to take an observed example and encode and decode it. However, we would also like to be able to call the encode and decode procedures as separate methods for other use cases besides training.

With the current Flax api a simple AutoEncoder can be written as follows:

from flax import nn
class AutoEncoder(nn.Module):
  
  def _create_modules(self, encoder_features, decoder_features):
    encoder = nn.Dense(features=encoder_features, name='encoder')
    decoder = nn.Dense(features=decoder_features, name='decoder')
    return encoder, decoder

  def apply(self, x, **hparams):
    encoder, decoder = self._create_modules(**hparams)
    z = encoder(x)
    return decoder(z)

  @nn.module_method
  def encode(self, x, **hparams):
    encoder, _ = self._create_modules(**hparams)
    return encoder(x) 

  @nn.module_method
  def decode(self, z, **hparams):
    _, decoder = self._create_modules(**hparams)
    return decoder(z) 

A number of issues can be noticed in this example:

  1. hyper parameters need to be passed around manually from all module methods
  2. _create_modules behaves a lot like a constructor but also needs to be called manually
  3. we cannot directly call the module methods encode and decode from apply leading to even more code duplication

Proposal

We would like to improve the syntax of modules that require multiple methods and reuse of parameters, submodules, and hyperparameters across various methods.

The proposed syntax allows us to rewrite the AutoEncoder module as follows

class AutoEncoder(nn.Module):

  def setup(self, encoder_features, decoder_features, **kwargs):
    self.encoder = nn.Dense.shared(features=encoder_features, name='encoder')
    self.decoder = nn.Dense.shared(features=decoder_features, name='decoder')
    return kwargs

  def apply(self, x):
    z = self.encode(x)
    return self.decode(z)

  @nn.module_method
  def encode(self, x):
    return self.encoder(x)

  @nn.module_method
  def decode(self, x):
    return self.decoder(x)

model_def = AutoEncoder.partial(encoder_features=4, decoder_features=8)
_, params = model_def.init(rng, x)
model = nn.Model(model_def, params)
# use apply function for training
x_recon = model(x)
# two step encode+decode
z = model.encode(x)
x_recon = model.decode(z)

A few differences with respect to the introduction example:

  1. a constructor (setup) defines shared modules and assigns them to fields.
  2. the constructor defines the hyperparameters and they are no longer passed around by other methods.
  3. apply reuses the module methods to avoid code duplication.

A few changes are required to make the new syntax work

  1. When a Module is constructed we must first call the setup function. The setup function receives all kwargs and returns the remaining keyword arguments that should be passed to the module method.

  2. when calling a module_method using self.some_module_method(...) it behaves as an ordinary python method.

An implementation of this proposal lives in draft PR #104

Alternatives

The main issue in this proposal is determining which arguments are passed to setup. There are a few variations that can be considered:

  1. Introspection is used to determine which keyword arguments belong to setup.

  2. Require users to provide a list of construction arguments

  3. Pass all keyword arguments to setup. This would make the implementation easier but would require most module methods to include something like **unused_kwargs to work correctly.

  4. [CURRENT PROPOSAL] setup receives all keyword arguments and returns a dictionary of keyword arguments that should be passed to the apply method and other module methods

[Question] Parameters

Hi, I am curious how Flax deals with parameters given you see code like this

def apply(self, x):
   x = Conv(x, ...)

where a new instance of the module Conv is created every time during apply. I am even more curious about how parameter sharing will be approached. I think it will be important to explain this because all pytorch/tf.keras users would expect these layers to be instantiated during __init__ and used in apply.

FLIP: dtype API

Goals

The high level goal of this proposal is to provide a consistent api for dealing with precision of parameters and computation for the various floating point types currently available in Jax.

The default precision is currently float32.
CPUs & GPUs also support double precision (float64).
Half precision types are more complicated because GPUs and TPUs support float16 and bfloat16 respectively. Both have reduced precision and float16 additionally has a reduced range compared to float32.

  1. Built-in modules should respect the dtype of their inputs
  2. Avoid automatically using half precision for computations with known numerical instability
  3. It should be possible to control the dtype of parameters separately from the dtype of computations

Proposal

To determine the dtype of the computation we use the dtype of the most precise input which we call the input dtype.

  1. initializers default to float32 precision but Modules should work with half or double precision too. Users can pass initializers that return an ndarray with (b)float16 or float64 dtypes. Similarly, users can decide to cast all parameters to a different dtype after initialization.

  2. if the input dtype is float64 all computation is done in float64.

  3. if the input dtype is float32 all computation is done in float32. All Modules are currently implemented and tested using float32 as the default. No surprises here.

  4. if the input dtype is (b)float16 all outputs are stored in (b)float16 and matrix computations (conv & dot) are performed in (b)float16. The guarantees are intentionally minimal here. We only guarantee matrix computations to run at half precision. This fits well with how current hardware accelerators are built. Both modern GPUs and TPUs have special hardware for accelerating matrix(-like) computations, and the computational intensity of matrix multiplications makes them compute bound, so that we can effectively use the increased FLOPS. Storing (intermediate) outputs in half precision brings the secondary benefit of these types, which is reduced memory consumption.
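The "most precise input" rule can be sketched without any array library by ranking dtype names. The ranking table and the `input_dtype` helper are assumptions made for illustration; in particular, the proposal does not say how a mix of float16 and bfloat16 should be resolved, so this sketch simply keeps whichever half type comes first.

```python
# Assumed precision ranking; both half types share the lowest rank.
PRECISION = {"float16": 0, "bfloat16": 0, "float32": 1, "float64": 2}


def input_dtype(*dtypes):
    """Return the most precise dtype name among the inputs."""
    # max() keeps the first of several equally ranked names.
    return max(dtypes, key=PRECISION.__getitem__)


mixed = input_dtype("bfloat16", "float32")
half = input_dtype("float16", "float16")
double = input_dtype("float32", "float64", "bfloat16")
```

Under rules 2 to 4 above, `mixed` would select float32 computation, `half` would select half-precision matrix computations, and `double` would select float64.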

Alternatives

  1. use a dtype keyword argument (current api). The downside of this is that it is pretty verbose and might suggest to the user that we will strictly adhere to doing all computations with the given dtype. It is also easy to accidentally forget passing the dtype to a submodule which results in a silent error.

  2. add dtype arguments for parameters. This could end up being very verbose and error prone just like (1). Another downside is that the current jax random ops don't have good support for non-default dtypes, often resulting in errors.

  3. execute all computation in the given dtype irrespective of numerical stability. The benefit of this approach is that it results in more explicit and simpler APIs, with the big downside that the Modules library will support combinations of Modules and dtypes that are likely to result in numerical instability. This could also lead to premature optimization given that many computations on modern accelerators are bottlenecked by memory, not compute, so low precision compute does not automatically lead to increased performance.

Create near state of the art object detection example

An object detection example would be extremely useful. A very good single-stage detection model that runs on Microsoft COCO would be an extremely useful starting point for flax users who need to do detection.

The single-stage detection models get pretty good results these days and should be simpler to implement, but it would be good if someone who is familiar with the detection literature could comment on this issue and suggest something that isn't too far from the state of the art (at least as good or better than RetinaNet).

[Question] Best way to implement stochastic reinforcement learning actors?

Hello!

First of all, thanks for this project - it is a lifesaver! So I wanted to get familiar with JAX so I decided to implement a few deep reinforcement learning algorithms as a side project. I initially approached the problem by subclassing flax.nn.Module as follows:

import flax.nn as nn
import jax.numpy as jnp
import jax.random as random

class MLPCategoricalActor(nn.Module):
    def apply(self, obs, act, action_space=None, rng=None,
              hidden_sizes=(64, 64), activation_fn=nn.tanh, output_fn=None):
        assert action_space is not None, "Action space must be specified."
        if rng is None:
            rng = nn.make_rng()
        act_dim = action_space.n
        logits = _MLP(obs, sizes=list(hidden_sizes) + [act_dim], activation_fn=activation_fn)
        pi = random.categorical(rng, logits)
        logp_all = nn.log_softmax(logits)
        logp = jnp.multiply(one_hot(act, act_dim), logp_all).sum(axis=1)
        logp_pi = jnp.multiply(one_hot(pi, act_dim), logp_all).sum(axis=1)
        return pi, logp, logp_pi

I very quickly ran into the problem of having a duplicate parameter 'rng'. I dug into the flax code and discovered that 1) dropout was not implemented as a module as I believed and 2) the ModuleFrame has an rng param that I can't access (apparently). I came up with three solutions:

  1. Get rng from the nn.stochastic context, but that would require wrapping the entire training function with it which seems a little weird to me.

  2. Use the same solution as in the VAE example and pass rng each time as a positional argument.

  3. Mix both solutions and try to get rng from a kwarg and if that fails fallback to the context. This may lead to a problem if someone sets the kwarg with a call to partial...

I wanted to ask you how you would go about this? My main concerns are code reusability and reproducibility.

Use module_method for RNNCellBase.initialize_carry

The RNNCellBase interface requires passing consistent arguments to both initialize_carry and apply. If the caller isn't careful to pass the same args (e.g. cell_size) to both methods, things will immediately break. If these hyperparams were instead specified only once, users couldn't make that mistake.

(This isn't an issue for the existing cells because they infer cell sizes from the carry argument, but this issue can come up for more complex cells.)

If initialize_carry was defined as a module_method, shared hyperparameters could be applied using .partial()

This would also make it possible for an RNN to learn its initial state.
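The idea of specifying hyperparameters once for both methods can be sketched with `functools.partial` on a toy cell. `ToyCell` and the `bind` helper are hypothetical and only mimic the shape of the RNNCellBase interface; they are not the Flax implementation.

```python
import functools


class ToyCell:
    """Stand-in for an RNN cell; hyperparameters are plain arguments."""

    @staticmethod
    def initialize_carry(batch_size, cell_size):
        return [[0.0] * cell_size for _ in range(batch_size)]

    @staticmethod
    def apply(carry, inputs, cell_size):
        # Fails loudly if the carry was built with a different cell_size.
        assert all(len(row) == cell_size for row in carry)
        return [[c + x for c in row] for row, x in zip(carry, inputs)]


def bind(cell, **hparams):
    """Bind shared hyperparameters once for both methods (hypothetical helper)."""
    return (functools.partial(cell.initialize_carry, **hparams),
            functools.partial(cell.apply, **hparams))


init_carry, step = bind(ToyCell, cell_size=3)
carry = init_carry(batch_size=2)
new_carry = step(carry, [1.0, 2.0])
```

Because `cell_size` is bound in one place, the two methods cannot disagree about it.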

`Module.partial().__qualname__` is not copied

I don't know if this is expected, but Module.partial does not seem to forward __qualname__.

To reproduce:

class MyModule(flax.nn.Module):
  def apply(self):
    pass

print(MyModule.partial().__name__)  # MyModule
print(MyModule.partial().__qualname__)  # Module.partial.<locals>.PartialModule
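One way a dynamically created subclass could forward both names is sketched below. The `Module` base class here is a stand-in written for the example and does not reproduce how flax.nn.Module.partial is actually implemented.

```python
class Module:
    """Stand-in base class; not the real flax.nn.Module."""

    @classmethod
    def partial(cls, **kwargs):
        class PartialModule(cls):
            defaults = kwargs

        # Forward both names so reprs and error messages show the original class.
        PartialModule.__name__ = cls.__name__
        PartialModule.__qualname__ = cls.__qualname__
        return PartialModule


class MyModule(Module):
    pass


partial_cls = MyModule.partial(features=4)
names_match = (partial_cls.__name__ == MyModule.__name__
               and partial_cls.__qualname__ == MyModule.__qualname__)
```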

Incorrect init for LSTM

The code below appears incorrect - I may be misunderstanding though. As per the comment you're summing 2 dense layers so one has no bias - but dense_h has bias=False and you're passing the bias_init, dense_i has bias=True but is using the default bias_init? Should be the other way around or pass bias_init to both?

Since the default bias_init is zeros for both Dense.apply() and LSTMCell.apply() it has no impact unless a different bias_init is supplied?

https://github.com/google-research/flax/blob/e7247d58e4f3460c03da5f935cb83d9c0883a97c/flax/nn/recurrent.py#L97-L103

model and optimizer.target diff

I am working on the pix2pix example. I don't get an error when I pass the picture to the model, but I do get an error when I pass it to model_optimizer.target. Why?

Code

import jax
import flax

import numpy as onp
import jax.numpy as jnp


OUTPUT_CHANNELS = 3

class DownSample(flax.nn.Module):
  def apply(self, x, features, size, apply_batchnorm=True):
    x = flax.nn.Conv(x, features=features, kernel_size=(size, size), strides=(2, 2), padding='SAME', bias=False)
    if apply_batchnorm:
      x = flax.nn.BatchNorm(x)
    x = flax.nn.leaky_relu(x)
    return x

class UpSample(flax.nn.Module):
  def apply(self, x, features, size, apply_dropout=True):
    x = flax.nn.ConvTranspose(x, features=features, kernel_size=(size, size), strides=(2, 2), padding='SAME', bias=False)
    x = flax.nn.BatchNorm(x)
    if apply_dropout:
      x = flax.nn.dropout(x, 0.5)
    x = flax.nn.relu(x)
    return x

down_list = [[64, 4, False],
             [128, 4],
             [256, 4],
             [512, 4],
             [512, 4],
             [512, 4],
             [512, 4],
             [512, 4]]

up_list = [[512, 4, True],
           [512, 4, True],
           [512, 4, True],
           [512, 4],
           [256, 4],
           [128, 4],
           [64, 4]]

class Generator(flax.nn.Module):
  def apply(self, x):
    skips = []
    for down in down_list:
      x = DownSample(x, *down)
      skips.append(x)
    
    skips = list(reversed(skips[:-1]))
    for up, skip in zip(up_list, skips):
      x = UpSample(x, *up)
      x = jnp.concatenate((x,skip))
    
    x = flax.nn.ConvTranspose(x, features=OUTPUT_CHANNELS, kernel_size=(4,4), strides=(2,2), padding='SAME')
    x = flax.nn.tanh(x)
    return x

def create_model(key, batch_size, image_size, model_def):
  input_shape = (batch_size, image_size, image_size, 3)
  with flax.nn.stateful() as init_state:
    with flax.nn.stochastic(jax.random.PRNGKey(0)):
      _, initial_params = model_def.init_by_shape(key, [(input_shape, jnp.float32)])
      model = flax.nn.Model(model_def, initial_params)
  return model, init_state

def create_optimizer(model, learning_rate, beta):
  optimizer_def = flax.optim.Adam(learning_rate=learning_rate,
                                 beta1=beta)
  optimizer = optimizer_def.create(model)
  optimizer = flax.jax_utils.replicate(optimizer)
  return optimizer

key = jax.random.PRNGKey(0)
generator_model, generator_state = create_model(key, 1, 256, Generator)
generator_optimizer = create_optimizer(generator_model, 2e-4, 0.5)

test_input = jax.random.normal(jax.random.PRNGKey(1), (1, 256, 256, 3))
with flax.nn.stochastic(jax.random.PRNGKey(0)):
  prediction = generator_model(test_input)  # work with no error
  print('prediction ok')
with flax.nn.stochastic(jax.random.PRNGKey(0)):
  prediction_opt = generator_optimizer.target(test_input)

Error

ValueError: Existing shape (1, 4, 4, 3, 64) differs from requested shape (4, 4, 3, 64)

Related pr: #186

Thank you!

Error when JITting `Model.__call__`

E.g.

import jax
from flax import nn
layer=nn.Dense.partial(features=1)
key=jax.random.PRNGKey(0)
x=jax.random.normal(key, (20, 2))
_,params=layer.init(key, x)
layer_m=nn.Model(layer, params)
jax.jit(layer_m)(x)

errors with

TypeError                                 Traceback (most recent call last)
<ipython-input-2-2e4e0581e3f5> in <module>
      6 _,params=layer.init(key, x[0,...])
      7 layer_m=nn.Model(layer, params)
----> 8 jax.jit(layer_m)(x)

~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    148     flat_fun, out_tree = flatten_fun(f, in_tree)
    149     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 150                        name=flat_fun.__name__)
    151     return tree_unflatten(out_tree(), out)
    152

~/opt/anaconda3/lib/python3.7/site-packages/jax/linear_util.py in __name__(self)
    121   @property
    122   def __name__(self):
--> 123     return getattr(self.f, '__name__', '<unnamed wrapped function>')
    124
    125   def wrap(self, gen, gen_static_args, out_store) -> 'WrappedFun':

~/opt/anaconda3/lib/python3.7/site-packages/flax/nn/base.py in __getattr__(self, name)
    897   def __getattr__(self, name):
    898     value = getattr(self.module, name)
--> 899     if issubclass(value, Module):
    900       def wrapper(*args, **kwargs):
    901         return value.call(self.params, *args, **kwargs)

~/opt/anaconda3/lib/python3.7/abc.py in __subclasscheck__(cls, subclass)
    141         def __subclasscheck__(cls, subclass):
    142             """Override for issubclass(subclass, cls)."""
--> 143             return _abc_subclasscheck(cls, subclass)
    144
    145         def _dump_registry(cls, file=None):

TypeError: issubclass() arg 1 must be a class
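The traceback shows `issubclass` being called on whatever `getattr(self.module, name)` returns, which is not a class when the attribute is a plain value such as a name string. A possible guard, sketched with stand-in classes rather than the real flax/nn/base.py code, is to check `inspect.isclass` first.

```python
import inspect


class Module:
    """Stand-in base class; not the real flax.nn.Module."""


class Dense(Module):
    pass


def is_module_class(value):
    # issubclass() raises TypeError for non-classes, so check isclass() first.
    return inspect.isclass(value) and issubclass(value, Module)


checks = (is_module_class(Dense), is_module_class("Dense"), is_module_class(len))

try:
    issubclass("Dense", Module)  # the unguarded call from the traceback
    unguarded_raised = False
except TypeError:
    unguarded_raised = True
```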

Make `ModelParamTraversal` more public?

ModelParamTraversal is currently somewhat hidden within optim. But it is much more generally useful, for example for implementing weight-decay (not as a loss) or weight standardization or spectral norm (I think).
So it seems like putting it in traverse_util.py (where I'd look for it) would make sense.

TransformerLM lm1b model accepts float and out-of-vocabulary tokens as input

I was experimenting with the TransformerLM from the lm1b example (https://github.com/google/flax/blob/master/examples/lm1b/models.py).

I noticed that TransformerLM does not raise any error when:

  • The input is float (e.g. np.float32)
  • The input ids are outside the range [0, vocab_size)

Both of those seem confusing to me. Is this the expected behavior?

Example:

import flax
import jax
import jax.numpy as jnp
from flax.examples.lm1b import models as lm1b_models

model_cls = lm1b_models.TransformerLM.partial(vocab_size=32)

y, params = model_cls.init_by_shape(
  jax.random.PRNGKey(0),
  [((1, 3), jnp.float32)],  # < Tracing with float doesn't raise an error
)
model = flax.nn.Model(model_cls, params)

model(jnp.array([[0.3, 0.5, 2.9]]))  # < Float input doesn't raise an error
model(jnp.array([[100, 101, 102]]))  # < Out-of-vocabulary token ids don't raise an error

As a side note, I noticed that the input pipeline is using the deprecated TFDS sub-split API: tfds.Split.TRAIN.subsplit, which is incompatible with recent versions of TFDS (split='train[90%:]' should be used instead).

(minor cleanup) I also think

dataset_builder = tfds.builder(dataset_name, data_dir=data_dir)
...
train = tfds.load(dataset_name, data_dir=data_dir, split='train')
valid = tfds.load(dataset_name, data_dir=data_dir, split='test')

could be rewritten as:

builder = tfds.builder(dataset_name, data_dir=data_dir)
builder.download_and_prepare()  # No-op if data already exists
...
train = builder.as_dataset(split='train')
valid = builder.as_dataset(split='test')

This would avoid loading the dataset three times, which may save a few seconds of startup time when loading from a remote file system.

"Manual" updates to module parameters

I'm trying to implement a projected gradient descent method and need to apply a constraint to parameters of a nn.Module. For example, apply jnp.abs() to a specific parameter. I haven't been able to find a simple way to do this yet; let me know if there's a canonical approach.
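The issue does not state a canonical approach. As a generic illustration only, a projection step on a nested dict of parameters can be written as a pure function that returns a new tree. `update_param` is a hypothetical helper, and plain Python lists stand in for arrays; with JAX arrays the projection function could be `jnp.abs`.

```python
def update_param(params, path, fn):
    """Return a copy of nested dict `params` with fn applied at `path`."""
    key, *rest = path
    new = dict(params)  # shallow copy keeps the input unchanged
    new[key] = update_param(params[key], rest, fn) if rest else fn(params[key])
    return new


params = {"dense": {"kernel": [-1.0, 2.0], "bias": [-0.5]}}

# Project only the kernel onto the non-negative orthant.
projected = update_param(
    params, ("dense", "kernel"), lambda w: [abs(v) for v in w])
```

The original `params` is left untouched, which matches the immutable-parameter style used elsewhere in these threads.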

Documentation for recurrent

I'm studying RNNs using jax so I'm currently investigating flax. I think the documentation in the RNN module is incorrect or out of date.

https://github.com/google-research/flax/blob/e7247d58e4f3460c03da5f935cb83d9c0883a97c/flax/nn/recurrent.py#L21-L23

Results in TypeError: apply() missing 1 required positional argument: 'inputs'

Also create builds and evaluates the model and returns a (y, model), so I feel like the design has changed and the recurrent examples should either initialise the state before calling create (but that wouldn't scan), or call create_by_shape?

Edit: I found a test which seems to confirm that the docstring is incorrect i.e. the code below creates an initial carry and passes to create

https://github.com/google-research/flax/blob/e7247d58e4f3460c03da5f935cb83d9c0883a97c/tests/nn_test.py#L461-L468

2nd Edit:

Also, I'm slightly confused by LSTMCell.initialize_carry() - it requires a batch_dim, and returns an initialised (zero) state for each batch. I might be missing something but this seems to preclude using lax.scan() to process each batch sequentially using the state from the previous batch as the initial state for the next batch (or some other state estimator which is specifically what I'm attempting). For example I have 365 "trajectories" (timeseries) each consisting of 24 samples and 5 features. So the state should be size 5 and I want to scan each trajectory from some initial state I provide (the intent is to use another net to estimate the state).

init_by_shape does not work on pytrees

init_by_shape only supports a list of arrays as lazy arguments.
Instead it would be better to support arbitrary pytrees.

The easiest way to support this is by using the ShapeDtypeStruct in Jax similar to jax.eval_shape.

apply_gradient with no parameters gives ValueError

This issue is admittedly a corner case, but one we've run into. If we consider the following flax.nn.Module:

class Identity(flax.nn.Module):
    def apply(self, x):
        return x

We won't be able to call apply_gradient since the output from this line will be an empty list.

This should probably (?) be addressed since it's exceptional behavior that may surprise, but could see arguments for different ways of resolving. One simple answer is to just no-op, but there might be some higher-level concerns I'm not thinking about which say we don't even want parameterless modules (in which case, raise on construction).

Anyway, we've resolved for now by just adding a dummy parameter. Here's the full minimum example and the resulting value error:

import flax
import jax
import jax.numpy as jnp

class Identity(flax.nn.Module):
    def apply(self, x):
        return x

model_def = Identity.partial()
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1,)])
model = flax.nn.Model(model_def, params)

def loss_fn(model, x, y):
    y_hat = model(x)
    return jnp.square(y - y_hat).mean(), y_hat

optim_def = flax.optim.Adam(learning_rate=1.0)
optimizer = optim_def.create(model)

(loss, y_hat), grad = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target, 1.0, 2.0)
optimizer.apply_gradient(grad)

~/src/flax/flax/optim/base.py in apply_gradient(self, hyper_params, params, state, grads)
    135            for param, state, grad in zip(params_flat, states_flat, grads_flat)]
    136 
--> 137     new_params_flat, new_states_flat = list(zip(*out))
    138     new_params = jax.tree_unflatten(treedef, new_params_flat)
    139     new_param_states = jax.tree_unflatten(treedef, new_states_flat)

ValueError: not enough values to unpack (expected 2, got 0)
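The failure comes from unpacking `zip(*out)` when `out` is empty. A minimal sketch of the no-op resolution mentioned above is shown here; `unzip_pairs` is a hypothetical helper, not the code in flax/optim/base.py.

```python
def unzip_pairs(pairs):
    """Split a list of (param, state) pairs; tolerate an empty list."""
    if not pairs:
        # zip(*[]) yields nothing, so a two-way unpack would raise ValueError.
        return (), ()
    params, states = zip(*pairs)
    return params, states


empty = unzip_pairs([])
filled = unzip_pairs([(1, "a"), (2, "b")])

try:
    a, b = list(zip(*[]))  # the unguarded pattern from the traceback
    unguarded_raised = False
except ValueError:
    unguarded_raised = True
```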

Possible idea for Sequential implementation?

Wasn't sure if there was some philosophical reason not to have Sequential (many of these points might apply!), but I found having this simple abstraction useful. It's...not the prettiest, I admit.

class Sequential(nn.Module):
    def apply(self, x, modules, args):
        results = x
        for module, arg in zip(modules, args):
            results = module(results, **arg)
        return results

And the way you might use it is

model_def = Sequential.partial(modules=[Identity, Plus], args=[{}, {"z": 2}])
ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 2)])
model = nn.Model(model_def, params)
model(np.array([1,2]))

> DeviceArray([3, 4], dtype=int32)

As a complete example with the dummy modules:

import jax
import numpy as np
from flax import nn

class Identity(nn.Module):
    def apply(self, x):
        return x

class Plus(nn.Module):
    def apply(self, x, z):
        return x + z

class Sequential(nn.Module):
    def apply(self, x, modules, args):
        results = x
        for module, arg in zip(modules, args):
            results = module(results, **arg)
        return results

model_def = Sequential.partial(modules=[Identity, Plus], args=[{}, {"z": 2}])
ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 2)])
model = nn.Model(model_def, params)
model(np.array([1,2]))

Re-shuffle data each iteration

Some examples (e.g. nlp_seq) show clear patterns in their training loss (see TensorBoard).

This is probably because the training data isn't shuffled enough. We should consider adding shuffle with reshuffle_each_iteration=True to the train IO pipelines of existing examples.
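The behavior that `reshuffle_each_iteration=True` asks for, a fresh permutation on every pass over the data, can be sketched without any dependencies. The `epochs` generator below is an illustrative stand-in for an input pipeline, not code from the examples.

```python
import random


def epochs(data, num_epochs, seed=0):
    """Yield the data once per epoch, in a fresh order each time."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        order = list(data)
        rng.shuffle(order)  # new permutation every epoch
        yield order


runs = list(epochs(range(100), num_epochs=2))
```

Each element of `runs` contains the same 100 items, but the two epochs visit them in different orders, which is what breaks up periodic patterns in the loss curve.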

Optimize multiple models at once

Curious how I can use Flax to optimize two models with the same loss function.

I have an encoder-decoder setup where I'm using teacher forcing to train a decoder LSTM. This requires doing one encoding step of an input, then a lot of incremental decoding steps where loss is added after each one. I plan on doing this by having a distinct encoder and decoder network. Any quick intuitions for how I can get Flax to optimize both simultaneously? It looks like optimizers only take in a single object to optimize.

Thanks!

Fully deprecate `optimizer.compute_gradient`

We're replacing the use of optimizer.compute_gradient with calls to JAX's jax.grad or jax.value_and_grad.

  • The docstring for compute_gradient should specify that it's deprecated and print a warning, like for optimizer.optimize().
  • The ImageNet example shouldn't use compute_gradient

[Example] GraphSage with cora|citeseer|pubmed

Currently the GNN example uses a GCN model trained on Zachary's karate club dataset. However, both the model and the dataset are simplified, e.g., GCN is the most basic GNN layer and the dataset used can be simply represented as an edge_list.

I'd like to write a GraphSage model with one of the datasets in the title to make the examples under the GNN section more concrete.

BatchNorm error message

The flax.nn.normalization.BatchNorm error message "batch_stats should be provided if use_running_averages is True" is slightly misleading.

I think it would be clearer if it said something like "when use_running_averages is True either use a stateful context or provide batch_stats".

It wasn't clear to me why the imagenet example worked without providing batch_stats and my layer failed until I looked closer at the source and realised that it was using a stateful context.
