
flaxformer's Introduction

Flaxformer: transformer architectures in JAX/Flax

Flaxformer is a transformer library used primarily for NLP and multimodal research at Google. It supports many NLP research use cases, providing both off-the-shelf BERT and T5 models and several research projects built on shared components.

General library goals

The Flaxformer library aims to provide transformer models that are:

  • High performance: Models are annotated for use with the PJIT API, so they can be used to train the largest models.
  • Reusable: Components have self-contained configuration, and high-level modules like encoders and decoders make few assumptions about what their sub-modules look like.
  • Tested: We aim to employ a reasonable amount of unit testing, and we write tests whenever bugs are encountered. However, no guarantees are provided.
  • Maintainable: We have created a versioning strategy for our modules so that code refactors which alter the module structure can take place. This is tricky in Flax, because Flax generates a tree of parameters based on the exact module structure. Our approach lets us maintain compatibility with previously trained model checkpoints.
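Because Flax derives parameter names from the module tree, moving or renaming a sub-module changes the paths under which its weights are stored, so an old checkpoint no longer lines up with the refactored code. The sketch below illustrates the general idea of bridging that gap by rewriting path prefixes in a flattened parameter tree. It is a hypothetical example, not Flaxformer's actual versioning mechanism: `flatten`, `unflatten`, `migrate_params` and the `layer_0/...` paths are made up for illustration, and real Flax code would more likely use `flax.traverse_util.flatten_dict` and `unflatten_dict` instead of the hand-written helpers.

```python
# Hypothetical illustration of checkpoint-compatible refactoring;
# not Flaxformer's actual versioning code.
import collections.abc
from typing import Any, Dict, Mapping


def flatten(tree: Mapping, prefix: str = "") -> Dict[str, Any]:
    """Flattens a nested dict into {'a/b/c': leaf} form."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, collections.abc.Mapping):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def unflatten(flat: Mapping) -> Dict[str, Any]:
    """Inverse of flatten(): rebuilds the nested dict from '/'-joined paths."""
    tree: Dict[str, Any] = {}
    for path, value in flat.items():
        *parents, leaf = path.split("/")
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


def migrate_params(old_params: Mapping, prefix_renames: Mapping[str, str]) -> Dict[str, Any]:
    """Moves parameters from an old module layout to a new one.

    `prefix_renames` maps an old path prefix to its new prefix. A rule only
    matches whole path segments, so 'layer_0/wi' does not match 'layer_0/wide'.
    Paths that match no rule are kept unchanged; the first matching rule wins.
    """
    migrated: Dict[str, Any] = {}
    for path, value in flatten(old_params).items():
        for old, new in prefix_renames.items():
            if path == old or path.startswith(old + "/"):
                path = new + path[len(old):]
                break
        migrated[path] = value
    return unflatten(migrated)


# Hypothetical v1 layout: the MLP weights sat directly under the layer.
v1 = {"layer_0": {"wi": {"kernel": 1.0}, "wo": {"kernel": 2.0}}}
# Hypothetical v2 layout: the same weights now live in an "mlp" sub-module.
v2 = migrate_params(v1, {"layer_0/wi": "layer_0/mlp/wi", "layer_0/wo": "layer_0/mlp/wo"})
print(v2)  # {'layer_0': {'mlp': {'wi': {'kernel': 1.0}, 'wo': {'kernel': 2.0}}}}
```

Matching on whole path segments rather than raw string prefixes keeps a rename of one sub-module from accidentally catching a sibling whose name merely starts with the same characters.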

Code locations

Modeling components such as dense attention, layer norms, and MLP blocks can be found in the components/ directory.

Higher-level classes which combine these components can be found in the architectures/ directory. The current architecture file for the T5 family of models is architectures/t5/t5_architecture.py; this is a mid-level API requiring sub-components to be configured. A high-level starting point, exposing fewer parameters, is architectures/t5/t5_1_1.py.
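The split between a mid-level API (sub-components are passed in) and a high-level entry point (fewer parameters to set) can be sketched as dependency injection: a layer receives its sub-modules and relies only on their call signature. Every name below (`layer_norm`, `Mlp`, `EncoderLayer`, `make_default_layer`) is hypothetical, and the sketch is written with plain NumPy rather than Flax modules to stay self-contained; it is not Flaxformer's actual API.

```python
# Hypothetical sketch of the mid-level vs. high-level configuration pattern;
# plain NumPy stand-ins, not Flaxformer modules.
from dataclasses import dataclass
from typing import Callable

import numpy as np

Array = np.ndarray


def layer_norm(x: Array, eps: float = 1e-6) -> Array:
    """Normalizes the last axis to zero mean and unit variance (no scale/bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


@dataclass
class Mlp:
    """Two-layer ReLU MLP with fixed random weights (stand-in for a trained block)."""
    d_model: int
    d_ff: int
    seed: int = 0

    def __post_init__(self):
        rng = np.random.default_rng(self.seed)
        self.wi = rng.normal(scale=self.d_model ** -0.5, size=(self.d_model, self.d_ff))
        self.wo = rng.normal(scale=self.d_ff ** -0.5, size=(self.d_ff, self.d_model))

    def __call__(self, x: Array) -> Array:
        return np.maximum(x @ self.wi, 0.0) @ self.wo


@dataclass
class EncoderLayer:
    """Mid-level: sub-components are injected, not constructed internally."""
    norm: Callable[[Array], Array]
    mlp: Callable[[Array], Array]

    def __call__(self, x: Array) -> Array:
        # Pre-norm residual block; the layer only assumes each sub-component
        # maps [..., d_model] -> [..., d_model].
        return x + self.mlp(self.norm(x))


def make_default_layer(d_model: int = 8) -> EncoderLayer:
    """High-level: picks default sub-components so callers set fewer knobs."""
    return EncoderLayer(norm=layer_norm, mlp=Mlp(d_model=d_model, d_ff=4 * d_model))


x = np.random.default_rng(1).normal(size=(2, 5, 8))  # [batch, length, d_model]
layer = make_default_layer(d_model=8)
print(layer(x).shape)  # (2, 5, 8)
# Swapping a sub-component needs no change to EncoderLayer itself.
no_norm_layer = EncoderLayer(norm=lambda v: v, mlp=Mlp(d_model=8, d_ff=16, seed=1))
print(no_norm_layer(x).shape)  # (2, 5, 8)
```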

Relationship to other codebases

Flaxformer is primarily used by other research projects, in particular T5X. We hope to release examples demonstrating the integration of these codebases soon.

If you would like to use Flaxformer independently of T5X, please see the unit tests for examples of instantiating the models. In the medium term, we hope to provide more stand-alone examples of Flaxformer use.

Contributions

Unfortunately, we cannot accept contributions to the Flaxformer repo at this time, so any pull requests will be closed automatically. Please do file issues as needed, though!

Installing dependencies and running tests

First, we recommend installing a few dependencies manually,

pip3 install numpy sentencepiece 'tensorflow>=2.14.0'

This is a workaround to prevent pip backtracking on package versions; we believe there is either a version conflict in upstream packages, or pip's constraint solving process is imperfect.

Then, check out this repository. In its root directory, you can install it along with test dependencies by running,

pip3 install '.[testing]'

If you like, you can run the tests with pytest using the following invocation,

python3 -m pytest

Uninstalling

If you need to uninstall Flaxformer, please run,

pip3 uninstall flaxformer

Troubleshooting

Flax deps

Flaxformer is developed in close collaboration with the Flax team. There may be bugs if your Flax version is not up to date. To install the latest version from GitHub, please run,

pip3 uninstall flax
pip3 install git+https://github.com/google/flax
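Before reinstalling, it can help to confirm which versions are actually installed in the active environment. A minimal sketch using only the standard library's `importlib.metadata` (Python 3.8+); the helper name `installed_versions` is hypothetical, and the distribution names `jax`, `jaxlib`, `flax` and `flaxformer` are assumed to match what pip installed.

```python
# Hypothetical helper: report installed versions of the packages that most
# often need to agree with each other.
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("jax", "jaxlib", "flax", "flaxformer"),
) -> Dict[str, Optional[str]]:
    """Maps each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```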

Note

Flaxformer is a project maintained by a team in Google Research. It is not an official Google product.

flaxformer's Issues

How to run a simple inference on Switch base

Hi there!

First of all, awesome work on Switch Transformers 🔥
I was wondering if there is a simple example script or set of commands for running inference with the switch_base model?
Thanks !

Failed to map logical axes for target/decoder/logits...

I am getting the following error when fine-tuning longT5 model:

`
ValueError Traceback (most recent call last)
Input In [16], in <cell line: 21>()
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
19 train_using_gin()
---> 21 gin_utils.run(main_train)

File ~/Downloads/t5x/t5x/gin_utils.py:105, in run(main)
103 def run(main):
104 """Wrapper for app.run that rewrites gin args before parsing."""
--> 105 app.run(
106 main,
107 flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))

File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser)
310 callback()
311 try:
--> 312 _run_main(main, args)
313 except UsageError as error:
314 usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)

File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv)
256 sys.exit(retval)
257 else:
--> 258 sys.exit(main(argv))

Input In [15], in main_train(argv)
1 def main_train(argv: Sequence[str]):
2 """Wrapper for pdb post mortems."""
----> 3 _main(argv)

Input In [16], in _main(argv)
12 train_using_gin = gin.configurable(train)
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
---> 19 train_using_gin()

File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1603 scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
1604 err_str = err_str.format(name, fn_or_cls, scope_info)
-> 1605 utils.augment_exception_message_and_reraise(e, err_str)

File ~/opt/miniconda3/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message)
39 proxy = ExceptionProxy()
40 ExceptionProxy.__qualname__ = type(exception).__qualname__
---> 41 raise proxy.with_traceback(exception.__traceback__) from None

File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1579 new_kwargs.update(kwargs)
1581 try:
-> 1582 return fn(*new_args, **new_kwargs)
1583 except Exception as e: # pylint: disable=broad-except
1584 err_str = ''

Input In [7], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period, random_seed, use_hardware_rng, summarize_config_fn, inference_evaluator_cls, get_dataset_fn, concurrent_metrics, actions, train_eval_get_dataset_fn, run_eval_before_training, use_gda)
224 input_types = {
225 k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
226 }
227 init_or_restore_tick = time.time()
--> 228 train_state_initializer = utils.TrainStateInitializer(
229 optimizer_def=model.optimizer_def,
230 init_fn=model.get_initial_variables,
231 input_shapes=input_shapes,
232 input_types=input_types,
233 partitioner=partitioner)
234 # 3. From scratch using init_fn.
235 train_state = train_state_initializer.from_checkpoint_or_scratch(
236 restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter)

File ~/Downloads/t5x/t5x/utils.py:368, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
365 self._partitioner = partitioner
366 self.global_train_state_shape = jax.eval_shape(
367 initialize_train_state, rng=jax.random.PRNGKey(0))
--> 368 self.train_state_axes = partitioner.get_mesh_axes(
369 self.global_train_state_shape)
370 self._initialize_train_state = initialize_train_state
372 # Currently scanned layers require passing annotations through to the
373 # point of the scan transformation to resolve an XLA SPMD issue.
374
375 # init_fn is always(?) equal to model.get_initial_variables, fetch the model
376 # instance from the bound method.

File ~/Downloads/t5x/t5x/partitioning.py:892, in PjitPartitioner.get_mesh_axes(self, train_state)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
--> 892 flat_mesh_axes = {
893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))

File ~/Downloads/t5x/t5x/partitioning.py:893, in <dictcomp>(.0)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
892 flat_mesh_axes = {
--> 893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))

File ~/Downloads/t5x/t5x/partitioning.py:888, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes)
885 return flax_partitioning.logical_to_mesh_axes(logical_axes,
886 self._logical_axis_rules)
887 except ValueError as e:
--> 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e

ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
In call to configurable 'train' (<function train at 0x2b751e160>)

`
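For context on the error above: the traceback shows T5X's `PjitPartitioner.get_mesh_axes` flattening the train state into `/`-separated parameter paths, passing each parameter's logical axis names together with `self._logical_axis_rules` to Flax's `logical_to_mesh_axes`, and re-raising any `ValueError` as "Failed to map logical axes for" plus the path. The sketch below is a simplified, hypothetical re-implementation of a rule-based logical-to-mesh mapping in plain Python, not Flax's or T5X's code; the rule list and the axis names (`batch`, `embed`, `vocab`, `mlp`, `data`, `model`) are assumptions for illustration. It shows one way such a mapping can raise (an array whose logical axis names repeat); it does not diagnose the specific LongT5 failure.

```python
# Simplified, hypothetical logical-to-mesh axis mapping; not Flax's implementation.
from typing import List, Optional, Sequence, Tuple

Rules = Sequence[Tuple[str, Optional[str]]]


def logical_to_mesh_axes(
    logical_axes: Sequence[Optional[str]], rules: Rules
) -> Tuple[Optional[str], ...]:
    """Maps each logical axis name to a mesh axis name (None = replicated).

    Rules are tried in order. Each array dimension is decided by the first
    rule that names it, and a mesh axis is used at most once per array.
    """
    names = [n for n in logical_axes if n is not None]
    if len(names) != len(set(names)):
        raise ValueError(f"Duplicate logical axis names in {tuple(logical_axes)}")
    result: List[Optional[str]] = [None] * len(logical_axes)
    assigned = [False] * len(logical_axes)
    used_mesh_axes = set()
    for logical_name, mesh_name in rules:
        if logical_name not in logical_axes:
            continue
        pos = list(logical_axes).index(logical_name)
        if assigned[pos]:
            continue  # an earlier rule already decided this dimension
        if mesh_name is not None and mesh_name in used_mesh_axes:
            continue  # this mesh axis already shards another dimension
        result[pos] = mesh_name
        assigned[pos] = True
        if mesh_name is not None:
            used_mesh_axes.add(mesh_name)
    return tuple(result)


# Hypothetical rules: shard batch over 'data', vocab/mlp over 'model'.
RULES = [("batch", "data"), ("vocab", "model"), ("embed", None), ("mlp", "model")]
print(logical_to_mesh_axes(("embed", "vocab"), RULES))  # (None, 'model')
try:
    logical_to_mesh_axes(("embed", "embed"), RULES)
except ValueError as err:
    print(err)  # Duplicate logical axis names in ('embed', 'embed')
```

Dimensions that no rule mentions stay `None` (replicated) in this sketch, so a missing rule degrades to replication rather than an error; only inconsistent annotations, like the repeated name, are rejected outright.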

ImportError: cannot import name 'masking' from 'jax.interpreters'

Traceback (most recent call last):
  File "/mnt/data1/zhangxianrong/d3pm/text/main.py", line 35, in <module>
    from text import diffusion # pylint: disable=unused-import
  File "/mnt/data1/zhangxianrong/d3pm/text/diffusion.py", line 38, in <module>
    from text import models
  File "/mnt/data1/zhangxianrong/d3pm/text/models.py", line 33, in <module>
    from flaxformer.architectures.t5 import t5_architecture
  File "/opt/miniconda3/envs/python3.9-zxr/lib/python3.9/site-packages/flaxformer/architectures/t5/t5_architecture.py", line 35, in <module>
    from flaxformer.components import rich_attention_position_scores
  File "/opt/miniconda3/envs/python3.9-zxr/lib/python3.9/site-packages/flaxformer/components/rich_attention_position_scores.py", line 30, in <module>
    from flaxformer.components import dense
  File "/opt/miniconda3/envs/python3.9-zxr/lib/python3.9/site-packages/flaxformer/components/dense.py", line 23, in <module>
    from aqt.jax_legacy.jax import flax_layers as aqt_flax_layers
  File "/opt/miniconda3/envs/python3.9-zxr/lib/python3.9/site-packages/aqt/jax_legacy/jax/flax_layers.py", line 26, in <module>
    from aqt.jax_legacy.jax import compute_cost_utils
  File "/opt/miniconda3/envs/python3.9-zxr/lib/python3.9/site-packages/aqt/jax_legacy/jax/compute_cost_utils.py", line 27, in <module>
    from jax.interpreters import masking
ImportError: cannot import name 'masking' from 'jax.interpreters' (/opt/miniconda3/envs/python3.9-zxr/lib/python3.9/site-packages/jax/interpreters/init.p
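The traceback ends in AQT's legacy code importing `masking` from `jax.interpreters`, which the installed JAX does not provide, so simply importing `flaxformer.architectures.t5` fails. Whether an environment is affected can be checked up front without running that import chain. A minimal sketch using the standard library's `importlib.util.find_spec`; the helper name `module_available` is hypothetical.

```python
# Hypothetical pre-flight check: can a (sub)module be found in this environment?
import importlib.util


def module_available(module_name: str) -> bool:
    """Returns True if `module_name` can be found, without executing it.

    Note that find_spec does import the parent packages of a dotted name
    (e.g. `jax.interpreters` for `jax.interpreters.masking`); a missing
    parent package is reported as False rather than raising.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        return False


if __name__ == "__main__":
    print("jax.interpreters.masking available:",
          module_available("jax.interpreters.masking"))
```

If this reports False, the import chain shown above will most likely fail the same way, and the installed JAX and AQT versions need to be brought back into agreement.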

BERT Pre-Training

Hi,

I would like to test this flaxformer library to pre-train a BERT from scratch.

What is needed to create the pre-training data (MLM with a duplication factor) from my own corpus, using my own WordPiece vocabulary?

How can pre-training then be started?

I'm really excited to test it, any help is highly appreciated!
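As background for the question above: in the original BERT data-generation recipe, the "duplication factor" means each input sequence is masked several times with different random masks, so the model sees several masked variants of the same text. Within each pass, 15% of the non-special tokens are selected; of those, 80% become `[MASK]`, 10% become a random token, and 10% are left unchanged. A minimal, framework-agnostic sketch of that step follows. It is not a Flaxformer or T5X API; the function names, the `IGNORE` label value, and the special-token IDs are hypothetical, and the tokens are assumed to already be WordPiece IDs.

```python
# Hypothetical sketch of BERT-style MLM masking with a duplication factor;
# not part of Flaxformer or T5X.
import random
from typing import List, Sequence, Set, Tuple

IGNORE = -100  # hypothetical label value for positions that are not predicted


def mask_tokens(
    token_ids: Sequence[int],
    mask_id: int,
    vocab_size: int,
    special_ids: Set[int],
    rng: random.Random,
    mask_prob: float = 0.15,
) -> Tuple[List[int], List[int]]:
    """Applies BERT-style masking once; returns (inputs, labels)."""
    inputs = list(token_ids)
    labels = [IGNORE] * len(inputs)
    candidates = [i for i, t in enumerate(inputs) if t not in special_ids]
    num_to_mask = max(1, round(len(candidates) * mask_prob)) if candidates else 0
    for i in rng.sample(candidates, num_to_mask):
        labels[i] = inputs[i]  # the model must predict the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id  # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels


def create_mlm_examples(sequences, dupe_factor, mask_id, vocab_size, special_ids, seed=0):
    """Masks every sequence `dupe_factor` times with independent random masks."""
    rng = random.Random(seed)
    examples = []
    for _ in range(dupe_factor):
        for seq in sequences:
            examples.append(mask_tokens(seq, mask_id, vocab_size, special_ids, rng))
    return examples


CLS, SEP, MASK = 101, 102, 103  # hypothetical special-token IDs
seqs = [[CLS, 2000, 2001, 2002, 2003, SEP], [CLS, 3000, 3001, SEP]]
examples = create_mlm_examples(seqs, dupe_factor=3, mask_id=MASK,
                               vocab_size=30522, special_ids={CLS, SEP, MASK})
print(len(examples))  # 6
```

Masking offline with a fixed duplication factor trades disk space for variety; generating masks on the fly in the input pipeline is a common alternative that gives a fresh mask every epoch.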

CoLT5 gin files

Hello,

According to the paper "CoLT5: Faster Long-Range Transformers with Conditional Computation", CoLT5 was trained using Flaxformer.

Could you please share the gin files for training CoLT5?

Thanks in advance for your reply.
