
Landmark Attention

This repository contains the implementation of landmark attention as described in our paper:

Landmark Attention: Random-Access Infinite Context Length for Transformers
Amirkeivan Mohtashami, Martin Jaggi
Preprint: https://arxiv.org/abs/2305.16300

Training

For training, the landmark tokens are added during data preparation. The following command is an example of training a model on PG19 with landmark tokens added every 50 tokens:

python main.py \
    --config_format rotary \
    --model landmark \
    --n_embd 1024 \
    --n_head 8 \
    --n_layer 12 \
    --batch_size 16 \
    --sequence_length 512 \
    --acc_steps 8 \
    --wandb_project memory-llm \
    --dataset pg19 \
    --iterations 240000 \
    --dropout 0.0 \
    --positional_encoder rotary \
    --softmax_func mem_opt \
    --mem_freq 50 \
    --wandb \
    --save_checkpoint_freq 20000
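
For intuition, the following minimal sketch (not the repository's actual data pipeline) illustrates what --mem_freq does during data preparation: a landmark token is appended after every block of mem_freq ordinary tokens. The token id and function name here are hypothetical.

LANDMARK_ID = 50257  # hypothetical id of the landmark token in the vocabulary

def add_landmarks(token_ids, mem_freq=50, landmark_id=LANDMARK_ID):
    # Insert one landmark token after every `mem_freq` tokens.
    out = []
    for i, tok in enumerate(token_ids, start=1):
        out.append(tok)
        if i % mem_freq == 0:
            out.append(landmark_id)
    return out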

To run on multiple GPUs, use torchrun (e.g. torchrun --nproc_per_node=4) and pass --distributed_backend nccl to the main.py script. We suggest first running the script on a single GPU until training starts before switching to a multi-GPU setting, because the first process has to initialize the dataset, which can take long enough to cause a timeout during synchronization across GPUs. Once the initialization has been performed, however, the result is stored on disk, so subsequent runs start quickly.
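
For example, a sketch of a 4-GPU launch (only a subset of the flags is shown; the rest are the same as in the single-GPU command above):

torchrun --nproc_per_node=4 main.py \
    --distributed_backend nccl \
    --config_format rotary \
    --model landmark \
    --dataset pg19 \
    --mem_freq 50
    # (remaining flags as in the single-GPU command above)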

You will need to initialize the dataset before running the training script by running the prepare.py script in the corresponding dataset folder inside data/.
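
For example, for PG19 (the folder name data/pg19 is assumed here; check the data/ directory for the exact layout and any arguments prepare.py expects):

cd data/pg19
python prepare.py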

Inference

The code supports inference in various settings. To perform standard evaluation, disable the cache and set the chunk size (specified with the --mid_length flag) equal to the evaluation length (specified by --eval_seq_length). Landmarks can be used when mem_cache is enabled. The script eval_cmd_generator.py can be used to generate a bash script containing the commands for the evaluations in Tables 1 and 2 of the paper; the paths of the output models need to be updated inside the script.
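
For example, a sketch of generating and running the evaluation commands, assuming the generator prints the commands to standard output (it may instead write a file directly; the output file name here is arbitrary):

python eval_cmd_generator.py > eval_commands.sh
bash eval_commands.sh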

LLaMA fine-tuning

The code for fine-tuning LLaMA and testing the final model is available as a standalone project in the sub-directory "llama". An example of running the fine-tuning (from inside the sub-directory) is:

torchrun --nproc_per_node=8  train.py  \
    --model_name_or_path /llama_weights/7B_hf/ \
    --bf16 True \
    --output_dir /llama-redpajama-mem-15000-with-mem/  \
    --cache_dir /hf-cache/ \
    --num_train_epochs 1  \
    --per_device_train_batch_size 2     \
    --per_device_eval_batch_size 2     \
    --gradient_accumulation_steps 8     \
    --evaluation_strategy "no"     \
    --save_strategy "steps"     \
    --save_steps 2000     \
    --save_total_limit 2     \
    --learning_rate 2e-5     \
    --weight_decay 0.1     \
    --warmup_ratio 0.03     \
    --lr_scheduler_type "cosine"     \
    --logging_steps 1     \
    --fsdp "full_shard auto_wrap"     \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer'     \
    --tf32 True \
    --max_steps 15000

In the above example, the LLaMA weights (converted to Hugging Face format) should be in /llama_weights/7B_hf/.
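
If you are starting from the original LLaMA checkpoints, the conversion to Hugging Face format is typically done with the converter script shipped with the transformers library; the paths below are illustrative:

python -m transformers.models.llama.convert_llama_weights_to_hf \
    --input_dir /llama_weights/ \
    --model_size 7B \
    --output_dir /llama_weights/7B_hf/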

Fine-tuned Weights

We have released the weight diff between the original LLaMA 7B and the same model fine-tuned for 15000 steps on the RedPajama dataset with landmark attention here. You may use the weight_diff.py script to recover the weights:

python weight_diff.py recover --path_raw <path_to_original_llama7b_weights> --path_diff <path_to_weight_diff> --path_tuned <path_to_store_recovered_weights>

For an example of how to perform inference using landmarks, look at run_test.py.

Naming

During the development of this project, we decided to rename certain components. Since this decision was made late in the project timeline, you may still encounter the old names in the code (e.g. mem instead of landmark). We are working to address this.
