
meliad's Introduction

Meliad

This is not an officially supported Google product.

This code is provided "as-is" to the broader research community. Google does not promise to maintain or otherwise support this code in any way.

Introduction

The Meliad library is a collection of models that are being developed as part of ongoing research into various architectural improvements in deep learning. The name "meliad" is the Greek word for a tree nymph; a long-term goal of this research is to design architectures that can understand recursive and compositional structures, i.e. trees.

The library currently consists of several transformer variations, which explore ways in which the popular transformer architecture can be extended to better support language modeling over long sequences.

Transformer-XL with sliding window

This model is provided as a baseline. It is similar to the Transformer-XL architecture, but uses a T5-style relative position bias. A long sequence, such as a book, is divided into segments of fixed length, e.g. 4096 tokens. The segments are processed in order, with one segment per training step.

Attention within a segment is done locally using a sliding window that is typically smaller than the segment length. A causal mask ensures that each token can attend to exactly W previous tokens, where W is the window size, e.g. 512 or 1024. The complexity of attention is quadratic with respect to window size, but linear with respect to segment length, so the segment length is limited only by available device memory. Like Transformer-XL, the model caches the keys and values from the last window for use on the next training step, and thus implements truncated backpropagation through time over very long (book-length) works.

If the window and segment lengths are the same, then there is no sliding window (just the T-XL cache), and this model will behave like Transformer-XL. However, the cache is not differentiable, whereas the sliding window is, so there is some benefit to using segments that are longer than the window length. Gradients with the sliding window can potentially be backpropagated across the length of the entire segment.
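
As an illustration, here is a minimal sketch of sliding-window attention over a segment with a cached previous window. This is not the Meliad implementation; the function name, shapes, and use of plain dot-product attention are assumptions made for clarity.

import jax.numpy as jnp
from jax import lax, nn

def sliding_window_attention(q, k, v, k_cache, v_cache, window):
  # Toy sketch of sliding-window causal attention (not the Meliad code).
  # q, k, v:           [seg_len, dim]  queries/keys/values for the current segment
  # k_cache, v_cache:  [window, dim]   keys/values cached from the previous step
  seg_len = q.shape[0]
  # Prepend the cached keys/values so the first tokens of the segment can
  # still look back a full window; the cache is not differentiable.
  keys = jnp.concatenate([lax.stop_gradient(k_cache), k], axis=0)
  vals = jnp.concatenate([lax.stop_gradient(v_cache), v], axis=0)
  # Query position i of the segment sits at index window + i of `keys`.
  q_pos = jnp.arange(seg_len)[:, None] + window
  k_pos = jnp.arange(seg_len + window)[None, :]
  # Causal band of width `window` ending at the query position.
  mask = (k_pos <= q_pos) & (k_pos > q_pos - window)
  scores = q @ keys.T / jnp.sqrt(q.shape[-1])
  scores = jnp.where(mask, scores, -1e30)
  return nn.softmax(scores, axis=-1) @ vals

Because the segment's own keys and values stay inside the computation graph while the cache does not, gradients can flow across the whole segment but stop at the segment boundary, matching the behavior described above.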

Memorizing Transformer

The Memorizing Transformer equips one layer of the transformer with a large external memory that stores prior (key,value) pairs. Typical memory sizes are 32k or 64k tokens. In addition to local attention, the model can do k-nearest-neighbor lookup into the external memory, which allows it to handle long-range dependencies; the range is limited only by the size of the memory.

The external memory, like the T-XL cache, is not differentiable. Memory and the T-XL cache work well together; the memory is used for long-range lookups, while the cache is used for short-range lookups. However, memory should not be used with a sliding window, so the window and segment length should be the same.
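
As a rough illustration of the lookup (again a sketch, not the library's API; the function name and shapes are assumptions), each query retrieves its top-k nearest keys from the memory and attends over just those entries; in the full model this result is combined with local attention.

import jax.numpy as jnp
from jax import lax, nn

def knn_memory_attention(q, mem_k, mem_v, top_k=32):
  # Toy sketch of k-nearest-neighbor attention into an external memory.
  # q:             [seg_len, dim]   queries for the current segment
  # mem_k, mem_v:  [mem_size, dim]  (key, value) pairs from earlier segments
  # The memory, like the T-XL cache, is treated as non-differentiable.
  mem_k = lax.stop_gradient(mem_k)
  mem_v = lax.stop_gradient(mem_v)
  scores = q @ mem_k.T                             # [seg_len, mem_size]
  top_scores, top_idx = lax.top_k(scores, top_k)   # best k matches per query
  top_vals = mem_v[top_idx]                        # [seg_len, top_k, dim]
  weights = nn.softmax(top_scores / jnp.sqrt(q.shape[-1]), axis=-1)
  return jnp.einsum('sk,skd->sd', weights, top_vals)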

Block-Recurrent Transformer

The Block-Recurrent Transformer equips one layer of the transformer with a recurrent cell. The cell is structured similarly to an LSTM cell, but it is several orders of magnitude larger, and operates on blocks of tokens and blocks of recurrent state vectors. Recurrence is integrated with the sliding window mechanism; the block size is the same as the window size.

Recurrence serves a similar role to external memory, but is faster. The recurrent state has a fixed capacity, but unlimited range (in theory).
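
Schematically, the state update looks something like the gated rule below. This is only a sketch under the assumption of a simple LSTM-style gate; the actual gate configuration is selected by the recurrent gin files, and the block summary would come from attention between the state vectors and the current block of tokens.

import jax.numpy as jnp
from jax import nn

def recurrent_state_update(state, block_summary, w_z, w_r):
  # Toy sketch of a gated update of the recurrent state (illustrative only).
  # state:          [num_states, dim]  state vectors carried across blocks
  # block_summary:  [num_states, dim]  information gathered from the current block
  # w_z, w_r:       [dim, dim]         gate projection weights
  z = nn.sigmoid(block_summary @ w_z)   # how much new information to write
  r = nn.sigmoid(block_summary @ w_r)   # how much of the old state to keep
  candidate = jnp.tanh(block_summary)
  return r * state + z * candidate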

Installation instructions

Create and activate a Python virtual environment. (The commands given are for Linux.)

python -m venv my_env
source my_env/bin/activate

Install the required packages into the Python virtual environment. If you want to use GPUs, then JAX must be upgraded to use CUDA. Installing t5 after upgrading JAX may be necessary to avoid link errors (we don't know why).

pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install t5
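
As an optional sanity check after upgrading JAX, you can confirm that it sees your accelerators:

python -c "import jax; print(jax.devices())"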

On Unix systems, you may need to ensure that PYTHONPATH includes the current directory. All module names are given relative to the meliad root.

export PYTHONPATH=.:$PYTHONPATH

Run a small baseline model on a synthetic test dataset.

python transformer/ht_main.py --alsologtostderr \
--gin_file=base_htrans.gin \
--gin_file=size/small_test.gin

Configuring and running the model

Meliad uses gin to configure the model. The first gin file should always be base_htrans.gin, which supplies a default configuration. Other options are specified as additional files in the configs directory. Most options are orthogonal, but in some cases the order matters; inspect the contents of the gin files to determine the correct order.

Some important options are:

  • size/medium150M.gin The 150M parameter model in the paper.
  • options/positions_t5.gin Use a T5-style relative position bias.
  • options/seq_4096.gin Use a segment length of 4096 tokens.
  • options/window_1024.gin Use a sliding window of size 1024. (The default is 512).
  • options/lr_cosine_decay.gin Cosine decay learning rate schedule.

Tasks are also defined in gin files:

  • tasks/pg19_tokens.gin Run on PG19 with the default T5 sentencepiece vocabulary.

Other important command-line options:

  • --alsologtostderr View the progress of the model.
  • --workdir=/my/work/directory For checkpoints and tensorboard.
  • --load_dir=/location/of/pretrained/model For finetuning.
  • --default_data_dir=/location/of/tfds/datasets For tensorflow datasets.
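
For example, a full command that combines these pieces to train the 150M-parameter sliding-window baseline on PG19 could look like the following (illustrative; check the configs directory for the exact file names and any ordering constraints):

python transformer/ht_main.py --alsologtostderr \
--gin_file=base_htrans.gin \
--gin_file=size/medium150M.gin \
--gin_file=options/positions_t5.gin \
--gin_file=options/seq_4096.gin \
--gin_file=options/window_1024.gin \
--gin_file=tasks/pg19_tokens.gin \
--workdir=/my/work/directory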

For the Memorizing Transformer:

  • size/medium150M.gin The 150M parameter model in the paper.
  • options/positions_t5.gin Use a T5-style relative position bias.
  • options/seq_512.gin Segment length of 512. (Window is 512 by default).
  • options/external_memory_32k.gin Memorizing Transformer with a memory size of 32k.

For the Block-Recurrent Transformer:

  • size/medium150M.gin The 150M parameter model in the paper.
  • options/positions_t5.gin Use a T5-style relative position bias.
  • options/seq_4096.gin Segment length of 4096. (Window is 512 by default).
  • recurrent/bias_skip.gin The fixed:skip configuration.
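
Putting these together, an illustrative command for training the Block-Recurrent model on PG19 (again, verify file names and ordering against the configs directory) would be:

python transformer/ht_main.py --alsologtostderr \
--gin_file=base_htrans.gin \
--gin_file=size/medium150M.gin \
--gin_file=options/positions_t5.gin \
--gin_file=options/seq_4096.gin \
--gin_file=tasks/pg19_tokens.gin \
--gin_file=recurrent/bias_skip.gin \
--workdir=/my/work/directory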

meliad's People

Contributors

delesley


meliad's Issues

Issue running models

Steps to reproduce:

Run:

python transformer/ht_main.py --alsologtostderr \
--gin_file=base_htrans.gin \
--gin_file=size/small_test.gin

WARNING:absl:GlobalAsyncCheckpointManager is not imported correctly. Checkpointing of GlobalDeviceArrays will not be available. To use the feature, install tensorstore.
Traceback (most recent call last):
  File "/Users/minhnguyen/Documents/meliad/transformer/ht_main.py", line 25, in <module>
    from transformer import launcher
  File "/Users/minhnguyen/Documents/meliad/transformer/launcher.py", line 23, in <module>
    import training_loop
  File "/Users/minhnguyen/Documents/meliad/transformer/training_loop.py", line 32, in <module>
    import optimizer_config as opt_config
  File "/Users/minhnguyen/Documents/meliad/transformer/optimizer_config.py", line 21, in <module>
    from flax import optim
ImportError: cannot import name 'optim' from 'flax' (/Users/minhnguyen/opt/anaconda3/lib/python3.9/site-packages/flax/__init__.py)

I have followed all the installation steps in the README. The environment is macOS.

How to run TransformerXL experiment

Hi,

Thanks for your work. I would like to run the Transformer-XL model with the medium-150M settings.
Is removing recurrent/bias_skip.gin from the Block-Recurrent Transformer settings enough to convert it to Transformer-XL, or do I need to modify something else?

Edit: I understand that setting the segment length and window size to the same value is also required for Transformer-XL. Where do we set the maximum size of the cache then?

Thanks a lot!

Question regarding the paper

Hi, I'm trying to implement the paper in PyTorch, but my model seems to ignore the recurrent states. Is the sliding-window attention over blocks of tokens mandatory for getting it to work, or can it be trained with regular attention? It's the only thing I'm missing; I've already implemented the special gate initialization, but still no luck, and I'm trying to identify what's actually needed to get the model to attend to its recurrent states.
Also, would it be possible to fine-tune an existing model to use the recurrent layer?

Install required packages on m1 Mac

clu doesn't seem to support the M1 yet. Is there any workaround for this issue?

pip install clu==0.0.7

Collecting clu==0.0.7
  Using cached clu-0.0.7-py3-none-any.whl (92 kB)
Collecting ml-collections
  Using cached ml_collections-0.1.1.tar.gz (77 kB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: packaging in ./my_env/lib/python3.10/site-packages (from clu==0.0.7) (21.3)
Requirement already satisfied: absl-py in ./my_env/lib/python3.10/site-packages (from clu==0.0.7) (1.3.0)
Collecting cached-property
  Using cached cached_property-1.5.2-py2.py3-none-any.whl (7.6 kB)
Requirement already satisfied: flax in ./my_env/lib/python3.10/site-packages (from clu==0.0.7) (0.5.3)
ERROR: Could not find a version that satisfies the requirement tensorflow (from clu) (from versions: none)
ERROR: No matching distribution found for tensorflow
