Overview | Quick install | What does Flax look like? | Documentation
NEW: Check out the NNX API!
This README is a very short intro. To learn everything you need to know about Flax, refer to our full documentation.
Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.
Flax is being used by a growing community of hundreds of folks in various Alphabet research departments for their daily work, as well as a growing community of open source projects.
The Flax team's mission is to serve the growing JAX neural network research ecosystem -- both within Alphabet and with the broader community, and to explore the use-cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads. We are in the process of moving some remaining internal design docs and conversation threads to GitHub discussions, issues and pull requests. We hope to increasingly engage with the needs and clarifications of the broader ecosystem. Please let us know how we can help!
Please report any feature requests, issues, questions or concerns in our discussion forum, or just let us know what you're working on!
We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.
In case you want to reach us directly, we're at [email protected].
Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.
Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:
- Neural network API (flax.linen): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout
- Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device
- Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging
- Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b
You will need Python 3.6 or later, and a working JAX installation (with or without GPU support - refer to the instructions). For a CPU-only version of JAX:
pip install --upgrade pip # To support manylinux2010 wheels.
pip install --upgrade jax jaxlib # CPU-only
Then, install Flax from PyPI:
pip install flax
To upgrade to the latest version of Flax, you can use:
pip install --upgrade git+https://github.com/google/flax.git
To install some additional dependencies (like matplotlib) that are required by some parts of Flax but not installed by default, you can use:
pip install "flax[all]"
We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.
To learn more about the Module abstraction, check out our docs and our broad intro to the Module abstraction. For additional concrete demonstrations of best practices, refer to our guides and developer notes.
from typing import Sequence
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x
model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
model = CNN()
batch = jnp.ones((32, 64, 64, 10)) # (N, H, W, C) format
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    input_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    # cast to tuple so list-valued decoder_widths (as used below) also work
    self.decoder = MLP(tuple(self.decoder_widths) + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x
model = AutoEncoder(encoder_widths=[20, 10, 5],
decoder_widths=[5, 10, 20],
input_shape=(12,))
batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)
In-detail examples to train and evaluate a variety of Flax models for Natural Language Processing, Computer Vision, and Speech Recognition are actively maintained in the 🤗 Transformers repository.
As of October 2021, the 19 most-used Transformer architectures are supported in Flax and over 5000 pretrained checkpoints in Flax have been uploaded to the 🤗 Hub.
To cite this repository:
@software{flax2020github,
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.9.0},
year = {2024},
}
In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.
Flax is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.
flax's Issues
Polyak Averaging Params_ema initialization
Hi, I was applying the changes in your HOWTO to add Polyak averaging and there seems to be some code missing, specifically params_ema is not initialized and so the line
optimizer, params_ema = train_step(optimizer, params_ema, batch)
causes an UnboundLocalError (params_ema referenced before assignment).
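A minimal sketch of the missing initialization, assuming the pre-Linen flax.nn/flax.optim API that the HOWTO uses; CNN, train_step, train_ds and input_shape are placeholders for whatever the HOWTO defines. The key point is that params_ema must be bound to a copy of the initial parameters before the training loop:

  _, model = CNN.create_by_shape(jax.random.PRNGKey(0), [(input_shape, jnp.float32)])
  optimizer = optim.Momentum(learning_rate=0.1, beta=0.9).create(model)
  params_ema = model.params   # start the running average from the initial parameters
  for batch in train_ds:
    optimizer, params_ema = train_step(optimizer, params_ema, batch)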
TypeError: iteration over a 0-d array
When I run the following code, I get a "TypeError: iteration over a 0-d array" error from Jax. The code looks correct to me, and I don't understand where this error is coming from or how to fix the issue.
import jax
import jax.numpy as jnp
from flax import nn, optim
class CNN(nn.Module):
  def apply(self, x):
    x = nn.Conv(x, features=32, kernel_size=(3, 3))
    x = x.reshape((x.shape[0], -1))
    v = nn.Dense(x, features=1)
    return v

@jax.jit
def train_step(optimizer, observations, returns):
  def loss_fn(model):
    values_pred = model(observations)
    return jnp.square(values_pred - returns).mean()
  optimizer, _ = optimizer.optimize(loss_fn)
  return optimizer
batch_size = 7
input_shape = (batch_size, 32, 32, 5)
key = jax.random.PRNGKey(0)
_, model = CNN.create_by_shape(key, [(input_shape, jnp.float32)])
optimizer = optim.Adam(learning_rate=0.01).create(model)
observations = jax.random.normal(key, input_shape)
returns = jax.random.normal(key, (batch_size, ))
optimizer = train_step(optimizer, observations, returns)
produces the following stack trace
Traceback (most recent call last):
File "test.py", line 33, in <module>
optimizer = train_step(optimizer, observations, returns)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 150, in f_jitted
name=flat_fun.__name__)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py", line 605, in call_bind
outs = primitive.impl(f, *args, **params)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 449, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 223, in memoized_fun
ans = call(fun, *args)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 466, in _xla_callable
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 152, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "test.py", line 18, in train_step
optimizer, l, logits = optimizer.optimize(loss_fn)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/optim.py", line 266, in optimize
loss, aux, grad = self.compute_gradients(loss_fn)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/optim.py", line 250, in compute_gradients
(loss, aux), grad = grad_fn(self.target)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 413, in value_and_grad_f
ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 1293, in vjp
out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/ad.py", line 111, in vjp
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/ad.py", line 98, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 337, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 156, in call_wrapped
ans = gen.send(ans)
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api_util.py", line 72, in flatten_fun_nokwargs2
ans, aux = yield py_args, {}
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py", line 300, in __iter__
return iter(self.aval._iter(self))
File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/lax/lax.py", line 1497, in _iter
raise TypeError("iteration over a 0-d array") # same as numpy error
TypeError: iteration over a 0-d array
Any help is greatly appreciated!
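A hedged guess at the cause, based on the traceback above: Optimizer.optimize() computes gradients with jax.value_and_grad(loss_fn, has_aux=True), so loss_fn is expected to return a (loss, aux) pair and the result is unpacked as (optimizer, loss, aux). Returning a bare scalar loss makes JAX try to unpack (iterate over) a 0-d array. A sketch of the adjusted train_step:

  @jax.jit
  def train_step(optimizer, observations, returns):
    def loss_fn(model):
      values_pred = model(observations)
      loss = jnp.square(values_pred - returns).mean()
      return loss, values_pred          # return a (loss, aux) pair, not a bare scalar
    optimizer, _, _ = optimizer.optimize(loss_fn)   # unpacks (optimizer, loss, aux)
    return optimizer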
Reorganizing optim into directory structure?
Hello again! At the Princeton office, we work on, among other things, optimization algorithms for deep learning. We're interested in using flax and wanted to add some other well-known algorithms. Would you guys be open to reorganizing optim.py into a directory a la pytorch? Happy to submit a PR if so!
Usually, this would accompany a PR, but being new around here, wanted to understand how (if at all) you wanted to reorganize.
One possibility: All subclasses of OptimizerDef (except MultiOptimizer, which appears to have a circular dependency with OptimizerDef) live in their own files (e.g., Momentum, GradientDescent).
Cannot run seq2seq train example :(
Hi!
First, thank you very much for this project!!!
Second, I am simply trying to run this script to learn how to use GRU/LSTM: https://github.com/google/flax/blob/prerelease/examples/seq2seq/train.py and I have encountered 2 issues:
- batch_metrics is missing: although it is not an important issue, I just commented it out, but maybe you would like to do something about it :)
- An error that is beyond my skills:
Traceback (most recent call last):
File "/.../FlaxExample.py", line 288, in <module>
app.run(main)
File "/.../anaconda3/lib/python3.7/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/.../anaconda3/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "/.../FlaxExample.py", line 284, in main
_ = train_model()
File "/.../FlaxExample.py", line 271, in train_model
model = create_model()
File "/.../FlaxExample.py", line 174, in create_model
((1, get_max_output_len(), vocab_size), jnp.float32)])
File "/.../anaconda3/lib/python3.7/site-packages/flax/nn/base.py", line 183, in wrapper
return super_fn(*args, **kwargs)
File "/.../anaconda3/lib/python3.7/site-packages/flax/nn/base.py", line 392, in init_by_shape
return jax_utils.partial_eval_by_shape(lazy_init, input_specs)
File "/.../anaconda3/lib/python3.7/site-packages/flax/jax_utils.py", line 94, in partial_eval_by_shape
output_shapes = jax.eval_shape(lazy_fn, *input_structs)
File "/.../anaconda3/lib/python3.7/site-packages/jax/api.py", line 2023, in eval_shape
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
File "/.../anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 256, in abstract_eval_fun
instantiate=True)
File "/.../anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 337, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/.../anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 152, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/.../anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 152, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/.../anaconda3/lib/python3.7/site-packages/flax/jax_utils.py", line 88, in lazy_fn
master = leaves[0]._trace.master # pylint: disable=protected-access
File "/.../anaconda3/lib/python3.7/site-packages/jax/core.py", line 365, in __getattr__
attr = getattr(self.aval, name)
AttributeError: 'ShapedArray' object has no attribute '_trace'
I believe I have the latest Jax and Flax versions:
Flax: 0.0.1a0
Jax : 0.1.58
Thanks for your help :)
Investigate and fix cause of warnings defined in pytest.ini
With pytest.ini defined, all the warnings captured during a test execution are reported as exceptions.
There are a few warnings which are independent of flax's implementation (e.g., the DeprecationWarning and UserWarning captured due to importing gfile from tensorflow.compat.v2.io).
Ideally we should have ZERO entries as exemptions in pytest.ini. Investigate the root cause and fix them.
Non fully reproducible results on GPU
Although random key is fixed (e.g. jax.random.PRNGKey(0)), the results of different runs are always different.
My question is how one can fix the random behavior? Because my expectation is when I choose a fixed random key, all the runs should produce the same result.
Thank you in advance.
Use the following code to reproduce the issue (I simply take the MNIST example with shuffle removed):
import jax
import flax
import numpy as onp
import jax.numpy as jnp
import tensorflow_datasets as tfds
class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = jax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=10)
    x = flax.nn.log_softmax(x)
    return x

@jax.vmap
def cross_entropy_loss(logits, label):
  return -logits[label]

def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}

@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch['image'])
    loss = jnp.mean(cross_entropy_loss(
        logits, batch['label']))
    return loss, logits
  optimizer, _, _ = optimizer.optimize(loss_fn)
  return optimizer

@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds['image'] / 255.0)
  return compute_metrics(logits, eval_ds['label'])

def train():
  train_ds = tfds.load('mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.cache().batch(128)
  test_ds = tfds.as_numpy(tfds.load(
      'mnist', split=tfds.Split.TEST, batch_size=-1))
  _, model = CNN.create_by_shape(
      jax.random.PRNGKey(0),
      [((1, 28, 28, 1), jnp.float32)])
  optimizer = flax.optim.Momentum(
      learning_rate=0.1, beta=0.9).create(model)
  for epoch in range(10):
    for batch in tfds.as_numpy(train_ds):
      batch['image'] = batch['image'] / 255.0
      optimizer = train_step(optimizer, batch)
    metrics = eval(optimizer.target, test_ds)
    print('eval epoch: %d, loss: %.4f, accuracy: %.2f'
          % (epoch+1,
             metrics['loss'], metrics['accuracy'] * 100))

train()
train()
Language Model Notebook Broken
This is just a tracking bug report - the TPU language model colab notebook is currently broken due to a low-level change in jaxlib that's interacting badly w. the colab tpu backends. We should be able to fix it soon.
What am I missing?
uhmm...
432 stars....
featured on trending...
empty repository....
2 issues...
what am i missing?
Simplify the CIFAR10 example
Currently, the CIFAR10 example implements 2 architectures (Wide ResNet and PyramidNet), 2 regularization methods (Shake-shake and shake-drop), several learning rate schedules, and allows training various combinations of these. This makes the example large and unnecessarily complex.
- Can we just keep one or two of the above combinations?
- Which ones are more relevant / useful / educational?
- Should we turn some other combinations into HOW-TOs? Which ones?
Flattening parameters
Hi,
Great package! I like the syntax.
Is there an easy way to pack and unpack parameters for a flax.nn.Module? I saw this in the optimizer code:
new_params = jax.tree_unflatten(treedef, new_params_flat)
new_param_states = jax.tree_unflatten(treedef, new_states_flat)
as well as this:
def init_state(self, params):
  sub_states = []
  for traversal, opt in zip(self.traversals, self.sub_optimizers):
    params_t = list(traversal.iterate(params))
    state = opt.init_state(params_t)
    sub_states.append(state)
but was wondering if there was a simpler way to convert the parameters to a single array (which I can still compute gradients on)? I'm interested in doing some operations on the parameters as a single vector as part of the loss function.
Thanks,
Miles
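Not a Flax-specific answer, but a hedged sketch of one way to do this with plain JAX utilities: jax.flatten_util.ravel_pytree flattens any parameter pytree into a single 1-D array and returns a function to rebuild the original structure, and the round trip is differentiable. The names model, x and y below are placeholders, and model.params assumes the pre-Linen flax.nn.Model API:

  import jax
  import jax.numpy as jnp
  from jax.flatten_util import ravel_pytree

  def loss_with_flat_penalty(model, x, y):
    flat_params, unravel = ravel_pytree(model.params)   # one 1-D vector of every parameter
    penalty = 1e-4 * jnp.sum(flat_params ** 2)          # any operation on the flat vector
    return jnp.square(model(x) - y).mean() + penalty

  # gradients flow through the flattening, so this works under jax.grad:
  # grad = jax.grad(loss_with_flat_penalty)(model, x, y)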
Example: DQN
I would like to add a DQN example, is there interest in it?
The idea is to use only flax and JAX, being independent of rlax
Clarify docstring on nn.Collection
Improve documentation
Let's use this issue to collect all the places in which the documentation should be improved. If you think there's something missing from the list, please comment and we'll update the list.
As we work through this list, we'll update the description with links to PRs addressing each of the items.
- nn.Collection
- Module.apply
- Module.call
- Module should get a docstring
- nn.stateful
- nn.stochastic
- Module.get_param (explain when it should be used)
Project missing LICENSE FILE
De-duplicate various pmean and psum implementations
The implementations in flax.training.common_utils and flax.jax_utils should be removed in favor of using the ones in jax.lax directly.
nevermind
nevermind
A flax module without trainable parameter changes the rng
In a flax module, I turned a function that applies a bunch of numpy operations into a flax (sub-)module that has no trainable parameters, e.g.:
  def add(x, y):
    return x + y
to
  @nn.module
  def add(x, y):
    return x + y
I noticed that I get different numbers at the end. Seems this is because the submodule changes the random number generators by splitting the rng of the parent module, while this behaviour is not expected in this case.
Maybe it makes more sense to condition splitting the rng on the existence of trainable parameters in the submodule?
FLIP: rm 'shape' from Module.param call signature
Goals
In our current implementation, initializer
methods (which get passed to Module.param
) are required to have the following signature:
def initializer(rng_key, shape):
  # ...
  return initialized_parameters

and in Module.param we assert that initialized_parameters.shape == shape.
Sometimes an initializer needs more (or less) information than the shape of its output, and at the moment this is achieved by writing the initializer function within a Module definition so that it can close over other data that it requires. For example, weightnorm initialization, which is data-dependent, can be implemented as follows; note the necessity of the dummy shape arguments:
class Conv2D(nn.Module):
  def apply(self, inputs, features, kernel_size):
    strides = (inputs.ndim - 2) * (1,)
    conv = partial(
        lax.conv_general_dilated, window_strides=strides, padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
    in_features = inputs.shape[-1]
    kernel_shape = kernel_size + (in_features, features)

    def initializer(key, shape):
      # A weightnorm initializer generating a (direction, scale, bias) tuple.
      # Note that the shape argument is not used.
      direction = nn.initializers.normal()(key, kernel_shape)
      unnormed_out = conv(inputs, _l2_normalize(direction))
      mean = np.mean(unnormed_out, (0, 1, 2))
      var = np.std(unnormed_out, (0, 1, 2))
      return dict(direction=direction, scale=1 / var, bias=-mean / var)

    # We feed in None as a dummy shape argument to self.param. Currently
    # Module.params assumes that the initializer takes in a shape argument;
    # None acts as a flag to avoid shape checking.
    params = self.param('weightnorm_params', None, initializer)
    direction, scale, bias = [params[k] for k in ('direction', 'scale', 'bias')]
    return conv(inputs, _make_kernel(direction, scale)) + bias
This situation isn't terrible, but it does highlight the fact that the assumption that initializers depend on parameter shape and nothing else is a bit arbitrary.
A more flexible API, with initializer requiring only a JAX PRNG key, would mean more consistent implementations of different types of initializers, and might also help to correctly emphasize what Flax's role is here, namely to handle splitting and passing of PRNG keys to initializers and to setup the parameters data-structure (a nested dictionary).
Proposal
We propose to change the call signature of Module.param
from
param(self, name: str, shape: Shape, initializer: Callable[[Key, Shape], Array]):
to
param(self, name: str, initializer: Callable[[Key], Array]):
This change would lead to a slight simplification of the weightnorm example above, since the dummy shape arguments could be removed.
For existing Modules for which the initializer is a straightforward function of the parameter shape and no other data is required, we can alter the currying of the initializer definitions so that lines like
kernel = self.param('kernel', kernel_shape, kernel_init)
can be replaced by
kernel = self.param('kernel', kernel_init(kernel_shape))
There may be downsides to this approach which I, being a relative Flax noob, am unaware of. One thing is that we'd lose the shape checking in Module.params, but that seems like the kind of check which should be part of a test anyway.
Alternatives
I think the obvious alternative is to simply keep the current API. The change proposed above is relatively minor but it still would likely require a number of users to make changes to their own code.
FLIP: Make module instances semantically meaningful by not overriding `module.__new__`
Introduction
Currently, while Flax modules are defined by subclassing flax.nn.Module, those modules don't behave the same way that normal Python objects behave.
One of the large differences is that Flax Modules override __new__, meaning that module instances aren't a semantically meaningful thing in Flax at the moment. Right now, in Flax, what looks like module construction (nn.Dense(x, features=10)) actually does two things:
- Construct an object of type nn.Dense (using the non-documented API module.new_instance())
- Call the apply method on that instance and return it.
Some upsides of the current approach are:
- Modules are defined as a single function, as opposed to, e.g., the style of other libraries, such as Haiku, where you need to scroll up and down between __init__ and __call__ to fully understand what a module does.
- Calls to submodules are very concise, e.g. nn.Dense(x, features=10).
Some downsides of the current approach are:
- In order to reuse a module, you must use the module.shared() abstraction, which has a confusing mental model -- what does module.shared() return? A module class? A module instance? Moreover, which arguments must be passed into module.shared() in order for the shared module to be usable? (Behind the scenes shared is implemented on top of partial.)
- You can't instantiate a module directly, outside of another module. This leads to surprising things like new nn.Model(nn.Dense.partial(features=10), params) -- why do we need to use partial to instantiate a Model? What type does the first argument to nn.Model have? Is it a module class? Module instance?
- In a few spots in flax/nn/base.py there is code that does "kwarg mangling". Some of this code had bugs before. It would be nice to reduce the need for kwarg mangling.
- In order to support multiple methods on a module, the module_method decorator turns methods that aren't apply into new Modules. This is surprising; for example, how would I do the equivalent of module.call(params, *args) but to call a method foo that's not apply? That would be module.foo.call(params, *args). That's a pretty surprising mental model.
- Wanting to define shared parameters or submodules that work across multiple methods requires either using non-traditional patterns and/or more complexity in Flax core (see discussion on #161).
- apply was a special-cased method on modules.
Proposal
- No longer override __new__ in Modules
- Eliminate .partial()
- Potentially eliminate .shared() (though we may choose to keep it as a safeguard -- see below)
- Split up current modules' apply methods into the controlled use of Python 3.7 dataclasses (for storing module hyperparameters) and a "vanilla Python" __call__ method (or actually, any name you want) that only takes in the module input(s)
- This may even allow module instances to directly refer to a read-only version of their parameters, e.g.:
class Foo(Module):
  def __init__(x):
    dense = nn.Dense(features=16)
    x = dense(x)
    # `dense.params` is defined here; maybe also `dense.params.kernel` and `dense.params.bias`
For example, a simple Dense layer may look like this:
@dataclass
class Dense(Module):
  features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros

  def __call__(self, x):
    """Applies a linear transformation to the inputs along the last dimension."""
    kernel = self.param('kernel', (x.shape[-1], self.features), self.kernel_init)
    bias = self.param('bias', (self.features,), self.bias_init)
    return jnp.dot(x, kernel) + bias
Then, an MLP would look like this:
class MLP(Module):
  def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)
I believe that this proposals keeps the conciseness of current Flax, while having the potential to significantly reduce both implementation complexity and mental model complexity. The mental model in Flax now reduces to the same one as Keras (other than the fact that parameters are immutable)
For example, in this case re-using a module is trivial -- keep a reference to nn.Dense(features=16) and re-use that. (NOTE: We may choose to keep the safe-guarding behavior of .shared() that makes it hard to accidentally copy and paste code that accidentally re-uses modules. We can achieve that by having modules default to raising an error when __call__ is invoked a second time, unless .shared() was called on the module instance first.)
With this proposal, there's also no need for module.partial -- you can just use functools.partial(module.__call__) or functools.partial(module). (Though this is a bit different than in current Flax because the return value of functools.partial in itself isn't a module, rather it's a function. But maybe it was always confusing to understand module.partial -- does it override kwargs for all module methods? Just apply?)
Possible transition plan
Given the non-trivial amount of code written using Flax, and the fact that this proposal would change every module written with Flax, we need an upgrade plan.
I propose adding, alongside every new module in flax.nn, a function with the same name but lower-cased, that operates the same as in current Flax. These functions would be deprecated-on-arrival. E.g., alongside Dense as shown above we would also have
def dense(x, features, kernel_init, bias_init):
  """DEPRECATED. Use the new Module API: http://link/to/upgrade/guide."""
  return Dense(features, kernel_init, bias_init)(x)
Then the first part of the upgrade process is mainly search and replace "Dense" -> "dense", etc. In addition, some more manual changes will possibly be necessary for uses of .partial and .shared. Later, users can transition into the new API at a time they see fit.
Current State
@avital has a messy work-in-progress branch checking the viability of using dataclasses in this setting. Results so far seem cautiously promising, but more work is needed before this proposal is ready to be acted on.
Full end-to-end MNIST example gives error.
I copied and pasted the Full end-to-end MNIST example code, then ran train() and I got an error.
Colab environment packages versions:
- TensorFlow: '2.1.0'
- Jax: '0.1.58'
- Numpy: '1.17.5'
Error:
AttributeError Traceback (most recent call last)
<ipython-input-4-2da0ffaf5447> in <module>()
----> 1 train()
12 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in __getattr__(self, name)
363
364 try:
--> 365 attr = getattr(self.aval, name)
366 except KeyError:
367 raise AttributeError(
AttributeError: 'ShapedArray' object has no attribute '_trace'
Example: pix2pix
Project missing license file
FLIP: Support __init__ in Modules
Introduction
By default Modules are defined using only an apply function and unshared parameters and submodules. This makes it easy to write modules (with control flow) in a concise and readable way.
However, some modules don't fit well in this abstraction. For example consider an autoencoder. During model training we would like to take an observed example and encode and decode it. However, we would also like to be able to call the encode and decode procedures as separate methods for other use cases besides training.
With the current Flax api a simple AutoEncoder can be written as follows:
from flax import nn
class AutoEncoder(nn.Module):
def _create_modules(self, encoder_features, decoder_features):
encoder = nn.Dense(features=encoder_features, name='encoder')
decoder = nn.Dense(features=decoder_features, name='decoder')
return encoder, decoder
def apply(self, x, **hparams):
encoder, decoder = self._create_modules(**hparams)
z = encoder(x)
return decoder(z)
@nn.module_method
def encode(self, x, **hparams):
encoder, _ = self._create_modules(**hparams)
return encoder(x)
@nn.module_method
def decode(self, z, **hparams):
_, decoder = self._create_modules(**hparams)
return decoder(z)
A number of issues can be noticed in this example:
- hyper parameters need to be passed around manually from all module methods
- _create_modules behaves a lot like a constructor but also needs to be called manually
- we cannot directly call the module methods encode and decode from apply, leading to even more code duplication
Proposal
We would like to improve the syntax of modules that require multi methods and reuse of parameters, sub modules, and hyperparameters across various methods.
The proposed syntax allows us to rewrite the AutoEncoder module as follows:
class AutoEncoder(nn.Module):
def setup(self, encoder_features, decoder_features, **kwargs):
self.encoder = nn.Dense.shared(features=encoder_features, name='encoder')
self.decoder = nn.Dense.shared(features=decoder_features, name='decoder')
return kwargs
def apply(self, x):
z = self.encode(x)
return self.decode(z)
@nn.module_method
def encode(self, x):
return self.encoder(x)
@nn.module_method
def decode(self, x):
return self.decoder(x)
model_def = AutoEncoder.partial(encoder_features=4, decoder_features=8)
_, params = model_def.init(rng, x)
model = nn.Model(model_def, params)
# use apply function for training
x_recon = model(x)
# two step encode+decode
z = model.encode(x)
x_recon = model.decode(z)
A few differences w.r.t. the introduction example:
- a constructor (setup) defines shared modules and assigns them to fields.
- the constructor defines the hyperparameters and they are no longer passed around by other methods.
- apply reuses the module methods to avoid code duplication.
A few changes are required to make the new syntax work:
- When a Module is constructed we must first call the setup function. The setup function receives all kwargs and returns the remaining keyword arguments that should be passed to the module method.
- When calling a module_method using self.some_module_method(...) it behaves as an ordinary Python method.
An implementation of this proposal lives in draft PR #104
Alternatives
The main issue in this proposal is determining which arguments are passed to setup. There are a few variations that can be considered:
- Introspection is used to determine which keyword arguments belong to setup.
- Require users to provide a list of construction arguments.
- Pass all keyword arguments to setup. This would make the implementation easier but would require most module methods to include something like **unused_kwargs to work correctly.
- [CURRENT PROPOSAL] setup receives all keyword arguments and returns a dictionary of keyword arguments that should be passed to the apply method and other module methods.
[Question] Parameters
Hi, I am curious how Flax deals with parameters given you see code like this:
  def apply(self, x):
    x = Conv(x, ...)
where a new instance of the module Conv is created every time during apply. Even more curious about how parameter sharing will be approached. I think it will be important to explain this because all pytorch/tf.keras users would expect these layers to be instantiated during __init__ and used in apply.
FLIP: dtype API
Goals
The high level goal of this proposal is to provide a consistent api for dealing with precision of parameters and computation for the various floating point types currently available in Jax.
The default precision is currently float32.
CPUs & GPUs also support double precision (float64).
Half precision types are more complicated because GPUs and TPUs support float16 and bfloat16 respectively. Both have reduced precision and float16 additionally has a reduced range compared to float32.
- Built in modules should respect the dtype of its inputs
- Avoid automatically using half precision for computations with known numerical instability
- It should be possible to control the dtype of parameters separately from the dtype of computations
Proposal
To determine the dtype of the computation we use the dtype of the most precise input which we call the input dtype.
- Initializers default to float32 precision but Modules should work with half or double precision too. Users can pass initializers that return an ndarray with (b)float16 or float64 dtypes. Similarly, users can decide to cast all parameters to a different dtype after initialization.
- If the input dtype is float64, all computation is done in float64.
- If the input dtype is float32, all computation is done in float32. All Modules are currently implemented and tested using float32 as the default. No surprises here.
- If the input dtype is (b)float16, all outputs are stored in (b)float16 and matrix computations (conv & dot) are performed in (b)float16. The guarantees are intentionally minimal here: we only guarantee matrix computations to run at half precision. This fits in well with what current hardware accelerators look like. Both modern GPUs and TPUs have special hardware for accelerating matrix (like) computations, and the computational intensity of matrix multiplies makes them compute bound, such that we can effectively use the increased flops. Storing (intermediate) outputs opens up the secondary benefit of these half precision types, which is reduced memory consumption.
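A minimal sketch (not Flax's actual implementation) of how the rules above play out for a dense layer: parameters can stay in float32 while the matmul and the stored output follow the input dtype.

  import jax.numpy as jnp

  def dense_apply(params, x):
    dtype = x.dtype                           # the input dtype decides the compute dtype
    kernel = params['kernel'].astype(dtype)   # params themselves may be stored in float32
    bias = params['bias'].astype(dtype)
    return jnp.dot(x, kernel) + bias          # matmul and stored output in the input dtype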
Alternatives
- Use a dtype keyword argument (current API). The downside of this is that it is pretty verbose and might suggest to the user that we will strictly adhere to doing all computations with the given dtype. It is also easy to accidentally forget passing the dtype to a submodule, which results in a silent error.
- Add dtype arguments for parameters. This could end up being very verbose and error prone, just like (1). Another downside is that the current jax random ops don't have good support for non-default dtypes, often resulting in errors.
- Execute all computation in the given dtype irrespective of numerical stability. The benefit of this approach is that it results in more explicit and simpler APIs, with the big downside that the Modules library will support combinations of Modules and dtypes that are likely to result in numerical instability. This could also lead to premature optimization, given that many computations on modern accelerators are bottlenecked by memory, not compute, so low precision compute does not automatically lead to increased performance.
nn.stochastic context does not extend into jitted functions
This causes silent failures where nn.make_rng() inside of the jitted function always produces the same "random" numbers. This issue should be well-documented and errors should be thrown to protect against it.
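A hedged sketch of the usual workaround with the pre-Linen API: pass the PRNG key into the jitted function explicitly and open the stochastic scope inside it, so nn.make_rng() draws from a key that is an argument of the traced function rather than from outer context state (model and x below are placeholders).

  import jax
  from flax import nn

  @jax.jit
  def apply_model(model, x, rng):
    # the key is an explicit argument, so each call can supply fresh randomness
    with nn.stochastic(rng):
      return model(x)

  # out = apply_model(model, batch, jax.random.PRNGKey(step))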
Create near state of the art object detection example
An object detection example would be extremely useful. A very good single-stage detection model that runs on Microsoft COCO would be an extremely useful starting point for flax users who need to do detection.
The single-stage detection models get pretty good results these days and should be simpler to implement, but it would be good if someone who is familiar with the detection literature could comment on this issue and suggest something that isn't too far from the state of the art (at least as good or better than RetinaNet).
[Question] Best way to implement stochastic reinforcement learning actors?
Hello!
First of all, thanks for this project - it is a lifesaver! I wanted to get familiar with JAX, so I decided to implement a few deep reinforcement learning algorithms as a side project. I initially approached the problem by subclassing flax.nn.Module as follows:
import flax.nn as nn
import jax.numpy as jnp
import jax.random as random
class MLPCategoricalActor(nn.Module):
  def apply(self, obs, act, action_space=None, rng=None,
            hidden_sizes=(64, 64), activation_fn=nn.tanh, output_fn=None):
    assert action_space is not None, "Action space must be specified."
    if rng is None:
      rng = nn.make_rng()
    act_dim = action_space.n
    logits = _MLP(obs, sizes=list(hidden_sizes) + [act_dim], activation_fn=activation_fn)
    pi = random.categorical(rng, logits)
    logp_all = nn.log_softmax(logits)
    logp = jnp.multiply(one_hot(act, act_dim), logp_all).sum(axis=1)
    logp_pi = jnp.multiply(one_hot(pi, act_dim), logp_all).sum(axis=1)
    return pi, logp, logp_pi
I very quickly ran into the problem of having a duplicate parameter 'rng'. I dug into the flax code and discovered that 1) dropout was not implemented as a module as I believed and 2) the ModuleFrame has an rng param that I can't access (apparently). I came up with three solutions:
- Get rng from the nn.stochastic context, but that would require wrapping the entire training function with it, which seems a little weird to me.
- Use the same solution as in the VAE example and pass rng each time as a positional argument.
- Mix both solutions and try to get rng from a kwarg and, if that fails, fall back to the context. This may lead to a problem if someone sets the kwarg with a call to partial...
I wanted to ask you how you would go about this? My main concerns are code reusability and reproducibility.
Use module_method for RNNCellBase.initialize_carry
The RNNCellBase interface requires passing consistent arguments to both initialize_carry and apply. if the caller isn't careful to pass the same args (e.g. cell_size) to both methods, things will immediately break. If these hyperparams were instead specified only once, users couldn't make that mistake.
(This isn't an issue for the existing cells because they infer cell sizes from the carry argument, but this issue can come up for more complex cells.)
If initialize_carry was defined as a module_method, shared hyperparameters could be applied using .partial().
This would also make it possible for an RNN to learn its initial state.
Change "howto-" branches to "howto/"
`Module.partial().__qualname__` is not copied
I don't know if this is expected, but Module.partial does not seem to forward __qualname__.
To reproduce:
class MyModule(flax.nn.Module):
  def apply():
    pass

print(MyModule.partial().__name__)       # MyModule
print(MyModule.partial().__qualname__)   # Module.partial.<locals>.PartialModule
Incorrect init for LSTM
The code below appears incorrect - I may be misunderstanding though. As per the comment you're summing 2 dense layers so one has no bias - but dense_h has bias=False and you're passing the bias_init, dense_i has bias=True but is using the default bias_init? Should be the other way around, or pass bias_init to both?
Since the default bias_init is zeros for both Dense.apply() and LSTMCell.apply() it has no impact unless a different bias_init is supplied?
model and optimizer.target diff
I am working on the pix2pix example. I don't get an error when I give the picture to the model, but when I give the picture to model_optimizer.target I get an error. Why?
Code
import jax
import flax
import numpy as onp
import jax.numpy as jnp
OUTPUT_CHANNELS = 3
class DownSample(flax.nn.Module):
  def apply(self, x, features, size, apply_batchnorm=True):
    x = flax.nn.Conv(x, features=features, kernel_size=(size, size), strides=(2, 2), padding='SAME', bias=False)
    if apply_batchnorm:
      x = flax.nn.BatchNorm(x)
    x = flax.nn.leaky_relu(x)
    return x

class UpSample(flax.nn.Module):
  def apply(self, x, features, size, apply_dropout=True):
    x = flax.nn.ConvTranspose(x, features=features, kernel_size=(size, size), strides=(2, 2), padding='SAME', bias=False)
    x = flax.nn.BatchNorm(x)
    if apply_dropout:
      x = flax.nn.dropout(x, 0.5)
    x = flax.nn.relu(x)
    return x
down_list = [[64, 4, False],
[128, 4],
[256, 4],
[512, 4],
[512, 4],
[512, 4],
[512, 4],
[512, 4]]
up_list = [[512, 4, True],
[512, 4, True],
[512, 4, True],
[512, 4],
[256, 4],
[128, 4],
[64, 4]]
class Generator(flax.nn.Module):
  def apply(self, x):
    skips = []
    for down in down_list:
      x = DownSample(x, *down)
      skips.append(x)
    skips = list(reversed(skips[:-1]))
    for up, skip in zip(up_list, skips):
      x = UpSample(x, *up)
      x = jnp.concatenate((x, skip))
    x = flax.nn.ConvTranspose(x, features=OUTPUT_CHANNELS, kernel_size=(4, 4), strides=(2, 2), padding='SAME')
    x = flax.nn.tanh(x)
    return x

def create_model(key, batch_size, image_size, model_def):
  input_shape = (batch_size, image_size, image_size, 3)
  with flax.nn.stateful() as init_state:
    with flax.nn.stochastic(jax.random.PRNGKey(0)):
      _, initial_params = model_def.init_by_shape(key, [(input_shape, jnp.float32)])
      model = flax.nn.Model(model_def, initial_params)
  return model, init_state

def create_optimizer(model, learning_rate, beta):
  optimizer_def = flax.optim.Adam(learning_rate=learning_rate,
                                  beta1=beta)
  optimizer = optimizer_def.create(model)
  optimizer = flax.jax_utils.replicate(optimizer)
  return optimizer

key = jax.random.PRNGKey(0)
generator_model, generator_state = create_model(key, 1, 256, Generator)
generator_optimizer = create_optimizer(generator_model, 2e-4, 0.5)
test_input = jax.random.normal(jax.random.PRNGKey(1), (1, 256, 256, 3))
with flax.nn.stochastic(jax.random.PRNGKey(0)):
  prediction = generator_model(test_input)  # work with no error
  print('prediction ok')
with flax.nn.stochastic(jax.random.PRNGKey(0)):
  prediction_opt = generator_optimizer.target(test_input)
Error
ValueError: Existing shape (1, 4, 4, 3, 64) differs from requested shape (4, 4, 3, 64)
Related pr: #186
Thank you!
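A hedged guess, since the full traceback isn't shown: create_optimizer calls flax.jax_utils.replicate(optimizer), which adds a leading device axis to every parameter (hence the (1, 4, 4, 3, 64) vs (4, 4, 3, 64) mismatch), so optimizer.target is no longer a plain single-device model. If that is indeed the cause, un-replicating first should behave like the bare model (assuming flax.jax_utils.unreplicate is available, as in recent versions):

  single_optimizer = flax.jax_utils.unreplicate(generator_optimizer)  # strip the device axis
  with flax.nn.stochastic(jax.random.PRNGKey(0)):
    prediction_opt = single_optimizer.target(test_input)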
Error when JITting `Model.__call__`
e.g.
import jax
from flax import nn

layer = nn.Dense.partial(features=1)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (20, 2))
_, params = layer.init(key, x)
layer_m = nn.Model(layer, params)
jax.jit(layer_m)(x)
errors with
TypeError Traceback (most recent call last)
<ipython-input-2-2e4e0581e3f5> in <module>
6 _,params=layer.init(key, x[0,...])
7 layer_m=nn.Model(layer, params)
----> 8 jax.jit(layer_m)(x)
~/opt/anaconda3/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
148 flat_fun, out_tree = flatten_fun(f, in_tree)
149 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 150 name=flat_fun.__name__)
151 return tree_unflatten(out_tree(), out)
152
~/opt/anaconda3/lib/python3.7/site-packages/jax/linear_util.py in __name__(self)
121 @property
122 def __name__(self):
--> 123 return getattr(self.f, '__name__', '<unnamed wrapped function>')
124
125 def wrap(self, gen, gen_static_args, out_store) -> 'WrappedFun':
~/opt/anaconda3/lib/python3.7/site-packages/flax/nn/base.py in __getattr__(self, name)
897 def __getattr__(self, name):
898 value = getattr(self.module, name)
--> 899 if issubclass(value, Module):
900 def wrapper(*args, **kwargs):
901 return value.call(self.params, *args, **kwargs)
~/opt/anaconda3/lib/python3.7/abc.py in __subclasscheck__(cls, subclass)
141 def __subclasscheck__(cls, subclass):
142 """Override for issubclass(subclass, cls)."""
--> 143 return _abc_subclasscheck(cls, subclass)
144
145 def _dump_registry(cls, file=None):
TypeError: issubclass() arg 1 must be a class
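Until the underlying __getattr__ issue is fixed, a hedged workaround is to jit a plain function that closes over the model (or takes it as an explicit argument, since nn.Model is a pytree) instead of jitting the Model object itself:

  import jax

  @jax.jit
  def apply_model(x):
    return layer_m(x)          # layer_m from the snippet above

  out = apply_model(x)

  # or keep the model as an explicit (pytree) argument:
  out = jax.jit(lambda model, x: model(x))(layer_m, x)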
Make `ModelParamTraversal` more public?
ModelParamTraversal is currently somewhat hidden within optim. But it is much more generally useful, for example for implementing weight-decay (not as a loss) or weight standardization or spectral norm (I think).
So it seems like putting it in traverse_util.py (where I'd look for it) would make sense.
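For reference, a hedged sketch of the kind of use the issue has in mind, assuming the filter-function constructor used with MultiOptimizer and the standard Traversal.update method; the decay factor and the 'kernel' path filter are illustrative only:

  import flax

  # select every parameter whose path contains 'kernel'
  kernel_traversal = flax.optim.ModelParamTraversal(lambda path, _: 'kernel' in path)

  # e.g. decoupled weight decay applied to those parameters outside the optimizer
  decayed_model = kernel_traversal.update(lambda p: p * (1.0 - 1e-4), model)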
TransformerLM lm1b model accept float and out-of vocabulary token as input
I was experimenting with the TransformerLM from the lm1b example (https://github.com/google/flax/blob/master/examples/lm1b/models.py).
I noticed that TransformerLM does not raise any error when:
- The input is float (e.g. np.float32)
- The input ids are outside the range [0, vocab_size]
Both of those seem confusing to me. Is it the expected behavior?
Example:
from flax.examples.lm1b import models as lm1b_models

model_cls = lm1b_models.TransformerLM.partial(vocab_size=32)
y, params = model_cls.init_by_shape(
    jax.random.PRNGKey(0),
    [((1, 3), jnp.float32)],  # < Tracing with float don't raise error
)
model = flax.nn.Model(model_cls, params)
model(jnp.array([[0.3, 0.5, 2.9]]))  # < Float don't raise error
model(jnp.array([[100, 101, 102]]))  # < Out of vocabulary token ids don't raise error
As a side note, I noticed that the input pipeline is using the deprecated TFDS sub-split API, tfds.Split.TRAIN.subsplit, which is incompatible with the latest versions of TFDS (split='train[90%:]' should be used instead).
(minor cleanup) I also think
dataset_builder = tfds.builder(dataset_name, data_dir=data_dir)
...
train = tfds.load(dataset_name, data_dir=data_dir, split='train')
valid = tfds.load(dataset_name, data_dir=data_dir, split='test')
could be rewritten as:
builder = tfds.builder(dataset_name, data_dir=data_dir)
builder.download_and_prepare() # No-op if data already exists
...
train = builder.as_dataset(split='train')
valid = builder.as_dataset(split='test')
This would avoid reloading the dataset three times, which may save a few seconds of startup when loading from a remote file system.
"Manual" updates to module parameters
I'm trying to implement a projected gradient descent method and need to apply a constraint to parameters of a nn.Module, for example apply jnp.abs() to a specific parameter. I haven't been able to find a simple way to do this yet; let me know if there's a canonical approach.
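A hedged sketch with the pre-Linen API shown elsewhere on this page: models and their parameters are pytrees, so a projection step can be expressed as an ordinary tree transformation after the gradient update. jnp.abs on every leaf is only for brevity; a real projection would target the specific parameter via its entry in optimizer.target.params. Re-wrapping the projected model into the optimizer state is left out of this sketch.

  import jax
  import jax.numpy as jnp

  optimizer = optimizer.apply_gradient(grad)                 # usual update step
  projected_model = jax.tree_util.tree_map(jnp.abs, optimizer.target)   # projection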
Documentation for recurrent
I'm studying RNN's using jax so I'm currently investigating flax. I think the documentation in the RNN module is incorrect or out of date.
Results in TypeError: apply() missing 1 required positional argument: 'inputs'
Also, create builds and evaluates the model and returns a (y, model), so I feel like the design has changed and the recurrent examples should either initialise the state before calling create (but that wouldn't scan), or call create_by_shape?
Edit: I found a test which seems to confirm that the docstring is incorrect, i.e. the code below creates an initial carry and passes it to create.
2nd Edit:
Also, I'm slightly confused by LSTMCell.initialize_carry() - it requires a batch_dim, and returns an initialised (zero) state for each batch. I might be missing something but this seems to preclude using lax.scan() to process each batch sequentially using the state from the previous batch as the initial state for the next batch (or some other state estimator which is specifically what I'm attempting). For example I have 365 "trajectories" (timeseries) each consisting of 24 samples and 5 features. So the state should be size 5 and I want to scan each trajectory from some initial state I provide (the intent is to use another net to estimate the state).
init_by_shape does not work on pytrees
init_by_shape only supports a list of arrays as lazy arguments.
Instead it would be better to support arbitrary pytrees.
The easiest way to support this is by using the ShapeDtypeStruct in Jax, similar to jax.eval_shape.
apply_gradient with no parameters gives ValueError
This issue is admittedly a corner case, but one we've run into. If we consider the following flax.nn.Module:
class Identity(flax.nn.Module):
  def apply(self, x):
    return x
We won't be able to call apply_gradient since the output from this line will be an empty list.
This should probably (?) be addressed since it's exceptional behavior that may surprise, but could see arguments for different ways of resolving. One simple answer is to just no-op, but there might be some higher-level concerns I'm not thinking about which say we don't even want parameterless modules (in which case, raise on construction).
Anyway, we've resolved for now by just adding a dummy parameter. Here's the full minimum example and the resulting value error:
import flax
import jax
import jax.numpy as jnp
class Identity(flax.nn.Module):
  def apply(self, x):
    return x

model_def = Identity.partial()
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1,)])
model = flax.nn.Model(model_def, params)

def loss_fn(model, x, y):
  y_hat = model(x)
  return jnp.square(y - y_hat).mean(), y_hat
optim_def = flax.optim.Adam(learning_rate=1.0)
optimizer = optim_def.create(model)
(loss, y_hat), grad = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target, 1.0, 2.0)
optimizer.apply_gradient(grad)
~/src/flax/flax/optim/base.py in apply_gradient(self, hyper_params, params, state, grads)
135 for param, state, grad in zip(params_flat, states_flat, grads_flat)]
136
--> 137 new_params_flat, new_states_flat = list(zip(*out))
138 new_params = jax.tree_unflatten(treedef, new_params_flat)
139 new_param_states = jax.tree_unflatten(treedef, new_states_flat)
ValueError: not enough values to unpack (expected 2, got 0)
Possible idea for Sequential implementation?
Wasn't sure if there was some philosophical reason not to have Sequential (many of these points might apply!), but I found having this simple abstraction useful. It's...not the prettiest, I admit.
class Sequential(nn.Module):
  def apply(self, x, modules, args):
    results = x
    for module, arg in zip(modules, args):
      results = module(results, **arg)
    return results
And the way you might use it is:
model_def = Sequential.partial(modules=[Identity, Plus], args=[{}, {"z": 2}])
ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 2)])
model = nn.Model(model_def, params)
model(np.array([1,2]))
> DeviceArray([3, 4], dtype=int32)
As a complete example with the dummy modules:
class Identity(nn.Module):
  def apply(self, x):
    return x

class Plus(nn.Module):
  def apply(self, x, z):
    return x + z

class Sequential(nn.Module):
  def apply(self, x, modules, args):
    results = x
    for module, arg in zip(modules, args):
      results = module(results, **arg)
    return results
model_def = Sequential.partial(modules=[Identity, Plus], args=[{}, {"z": 2}])
ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 2)])
model = nn.Model(model_def, params)
model(np.array([1,2]))
The guided tour should not use the deprecated `optimizer.compute_gradient()` API
See https://flax.readthedocs.io/en/latest/notebooks/flax_guided_tour.html (the notebook can be found here: https://github.com/google/flax/blob/master/docs/notebooks/flax_guided_tour.ipynb).
Instead, just call jax.grad or jax.value_and_grad directly.
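For reference, a hedged sketch of the replacement pattern; batch and cross_entropy_loss are borrowed from the MNIST reproduction snippet earlier on this page and stand in for whatever the guided tour defines:

  def loss_fn(model):
    logits = model(batch['image'])
    loss = jnp.mean(cross_entropy_loss(logits, batch['label']))
    return loss, logits

  (loss, logits), grad = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)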
VAE example uses deprecated `optimizer.optimize()`
Re-shuffle data each iteration
Some examples (e.g. nlp_seq) show clear patterns in their training loss (see TensorBoard).
This is probably because the training data isn't shuffled enough. We should consider adding shuffle with reshuffle_each_iteration=True to the train IO pipelines of existing examples.
Review and update all TFDS pipelines to use modern splitting API.
As noted in #136 the lm1b example we're still using the deprecated TFDS splitting API, we should review and update all of our pipelines to use the modern version. This is the tracking issue for those fixes.
Add tests to all examples that one step of training works without crashing.
Optimize multiple models at once
Curious how I can use Flax to optimize two models with the same loss function.
I have an encoder-decoder setup where I'm using teacher forcing to train a decoder LSTM. This requires doing one encoding step of an input, then a lot of incremental decoding steps where loss is added after each one. I plan on doing this by having a distinct encoder and decoder network. Any quick intuitions for how I can get Flax to optimize both simultaneously? It looks like optimizers only take in a single object to optimize.
Thanks!
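Not an authoritative answer, but a hedged sketch of one common pattern: since flax optimizers act on arbitrary pytrees, the encoder and decoder can be wrapped in a single dict and optimized together. encoder_model, decoder_model, x and y are placeholders here:

  import jax
  import jax.numpy as jnp
  import flax

  models = {'encoder': encoder_model, 'decoder': decoder_model}
  optimizer = flax.optim.Adam(learning_rate=1e-3).create(models)

  def loss_fn(models):
    z = models['encoder'](x)
    y_hat = models['decoder'](z)
    return jnp.square(y_hat - y).mean()

  grad = jax.grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)   # updates both models in one step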
Fully deprecate `optimizer.compute_gradient`
We're replacing the use of optimizer.compute_gradient with calls to JAX's jax.grad or jax.value_and_grad.
- The docstring for compute_gradient should specify that it's deprecated and print a warning, like for optimizer.optimize().
- The ImageNet example shouldn't use compute_gradient.
[Example] GraphSage with cora|citeseer|pubmed
Currently the GNN model uses a GCN model trained on Zachary's karate club dataset. However, both the model and the dataset are simplified, e.g., GCN is the most basic GNN layer and the dataset used can be simply represented as an edge_list.
I'd like to write a GraphSage model with one of the titled datasets to make the examples under the GNN section more concrete.
BatchNorm error message
The flax.nn.normalization.BatchNorm error message "batch_stats should be provided if use_running_averages is True" is slightly misleading.
I think it would be clearer if it said something like "when use_running_averages is True either use a stateful context or provide batch_stats".
It wasn't clear to me why the imagenet example worked without providing batch_stats and my layer failed until I looked closer at the source and realised that it was using a stateful context.