
LongHeads: Multi-Head Attention is Secretly a Long Context Processor [paper]

News

  • [2024/03/25] We successfully offload the KV cache to CPU during inference. This implementation significantly reduces memory usage and enables LLaMA-2-7b to achieve 100% accuracy with a 128k context on the passkey retrieval task!
  • [2024/03/25] We release the evaluation code for the passkey retrieval task.
  • [2024/03/24] We release the example code for LongHeads.

Overview

LongHeads is a training-free framework for extending the context window of large language models (LLMs) to more than 32× their original pre-training length. LongHeads runs efficiently in linear time, fits seamlessly with many LLMs that use relative positional encoding, and can be integrated with popular extrapolation methods such as Positional Interpolation (PI) and NTK-Dynamic RoPE.

[Figure: overview of the LongHeads scheme]
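
For intuition, here is a minimal sketch of Positional Interpolation (PI), one of the extrapolation methods LongHeads can be combined with. The function and parameter names are illustrative assumptions, not part of this repo:

import torch

def interpolated_position_ids(seq_len: int, pretrained_len: int = 4096) -> torch.Tensor:
    # Positional Interpolation (PI): linearly rescale the positions of a long
    # sequence so they all fall inside the model's pre-trained position range.
    positions = torch.arange(seq_len, dtype=torch.float32)
    if seq_len <= pretrained_len:
        return positions
    # e.g. with seq_len=16384 and pretrained_len=4096, every position is scaled by 0.25
    return positions * (pretrained_len / seq_len)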

🚀Quick Start

Environment Setup

pip install -r requirements.txt
# FlashAttention >= 2.3.6 is required; we use flash-attn==2.3.6
pip install flash-attn --no-build-isolation

Load model with LongHeads

# load the LongHeads model
import torch

from modeling_longheads import LlamaForCausalLM

longheads_config = {
    # chunk size used by LongHeads
    'window_size': 256,
    # attention window length of LongHeads (atten_length should be no larger than the model's pre-trained length)
    'atten_length': 4096,
    # during the encoding phase, streaming encoding of the long context with the chunk selection strategy begins once the input exceeds this length
    'begin_selective_length': 4096,
    # whether to offload the KV cache to CPU memory; if True, LongHeads can generate with 128k+ context lengths
    'cpu_offload': False,
    # whether to use batch encoding to speed up the encoding phase; if True, more memory is needed
    'batch_encoding': False,
    # batch size hyperparameter for batch encoding
    'encoding_batch_size': 128,
}
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, **longheads_config)
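
Assuming the returned model keeps the standard Hugging Face generate API, a minimal usage sketch continuing from the snippet above could look like this (the prompt and generation settings are illustrative):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = model.cuda().eval()

prompt = "Summarize the following document: ..."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
# decode only the newly generated tokens
print(tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))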

Run Inference Example

python example.py

Passkey Retrieval

LongHeads 128k

We successfully extend LLaMA-2-7b to 128k with LongHeads, without any additional training, and achieve 100% accuracy with a 128k context on the passkey retrieval task! With the KV cache offloaded to CPU, peak GPU memory usage is 26.51 GB for inference with a 64k context and 44.48 GB with a 128k context.

[Figure: passkey retrieval results at 128k]

Evaluation

bash passkey_retrieval/passkey_retrieval_script.sh
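
For reference, passkey retrieval hides a random number inside long filler text and asks the model to recall it. A hedged sketch of how such prompts are commonly constructed (the repo's script may differ in its details):

import random

def make_passkey_prompt(n_filler: int = 2000) -> tuple[str, str]:
    # Hide a random 5-digit passkey at a random position inside repetitive filler text,
    # then ask the model to recall it at the end of the context.
    passkey = str(random.randint(10000, 99999))
    filler = "The grass is green. The sky is blue. The sun is yellow. " * n_filler
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
    insert_at = random.randint(0, len(filler))
    context = filler[:insert_at] + needle + filler[insert_at:]
    question = "What is the pass key? The pass key is"
    return context + question, passkey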

TODOs

We will release the code in the following order; please stay tuned!

  • Release core code of LongHeads.
  • Release example code for usage.
  • Release passkey retrieval evaluation code.
  • Release code of LongHeads with efficient implementation.

Citation

If you find LongHeads useful or relevant to your project and research, please kindly cite our paper:

@misc{lu2024longheads,
      title={LongHeads: Multi-Head Attention is Secretly a Long Context Processor}, 
      author={Yi Lu and Xin Zhou and Wei He and Jun Zhao and Tao Ji and Tao Gui and Qi Zhang and Xuanjing Huang},
      year={2024},
      eprint={2402.10685},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
