
pyramidinfer's Introduction

[ACL 2024] PyramidInfer: Pyramid KV Cache Compression for High-throughput LLM Inference

Dongjie Yang, Xiaodong Han, Yan Gao, Yao Hu, Shilin Zhang, Hai Zhao

arXiv

Updates

  • [2024-06-17] We release the code for PyramidInfer; the details can be found here.

[WIP] This repository is still under construction. We will release the full code for evaluating the performance of PyramidInfer using OpenCompass.

Overview

Large Language Models (LLMs) have shown remarkable comprehension abilities but face challenges in GPU memory usage during inference, hindering their scalability for real-time applications such as chatbots. To accelerate inference, we store the computed keys and values (KV cache) in GPU memory. Existing methods study KV cache compression to reduce memory by pruning the pre-computed KV cache, but they neglect the inter-layer dependency between layers and the huge memory consumption of pre-computation. To address these deficiencies, we find that the number of crucial keys and values that influence future generations decreases layer by layer, and that we can extract them via the consistency in attention weights. Based on these findings, we propose PyramidInfer, a method that compresses the KV cache by layer-wise retention of crucial context. PyramidInfer saves significant memory by computing fewer keys and values without sacrificing performance. Experimental results show that PyramidInfer improves throughput by 2.2x compared to Accelerate, with over 54% GPU memory reduction in the KV cache.
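
To give a rough picture of the idea (this is an illustrative sketch, not the authors' exact implementation), the snippet below keeps only the context keys and values that a window of recent tokens attends to most, approximating how pivotal contexts (PvCs) can be selected; all tensor names, shapes, and the scoring rule are assumptions.

import torch

def select_pivotal_context(attn_weights: torch.Tensor,
                           keys: torch.Tensor,
                           values: torch.Tensor,
                           keep_ratio: float):
    """Illustrative sketch: keep the context KV that recent tokens attend to most.

    attn_weights: (batch, heads, recent_len, ctx_len) attention from recent tokens
    keys/values:  (batch, heads, ctx_len, head_dim) cached context keys/values
    keep_ratio:   fraction of context positions this layer retains
    """
    ctx_len = keys.shape[2]
    keep_len = max(1, int(ctx_len * keep_ratio))

    # Score each context position by the attention mass it receives, averaged
    # over heads and over the recent-token window (a consistency proxy).
    scores = attn_weights.mean(dim=1).mean(dim=1)                        # (batch, ctx_len)
    top_idx = scores.topk(keep_len, dim=-1).indices.sort(dim=-1).values  # keep original order

    idx = top_idx[:, None, :, None].expand(-1, keys.shape[1], -1, keys.shape[-1])
    return keys.gather(2, idx), values.gather(2, idx)

In PyramidInfer, the retained fraction shrinks at deeper layers, which is what produces the "pyramid" shape of the cache; the hyperparameters that control this schedule are described in the configuration section below.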

Getting Started

Run a demo

We recommend running PyramidInfer with a large batch size to see more significant memory reduction and efficiency improvements.

conda create -n pyramidinfer python=3.8 -y
conda activate pyramidinfer
pip install -r requirements.txt

python simple_infer_comparison.py --model_name_or_path meta-llama/Llama-2-7b-hf

Implementation of PyramidInfer

Please check models/modeling_llama_pyramidinfer.py for the implementation of PyramidInfer. More details can be found in the paper.
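
For orientation, here is a hedged sketch of loading the patched model directly. The way the PyramidInfer hyperparameters are attached is an assumption, and the supported entry point remains simple_infer_comparison.py; treat this only as an orientation aid.

# Hypothetical usage sketch -- config handling here is an assumption, not documented
# API; see simple_infer_comparison.py for the actual entry point.
import torch
from transformers import AutoTokenizer
from models.modeling_llama_pyramidinfer import LlamaForCausalLM  # drop-in for the stock class

model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
# PyramidInfer hyperparameters (see "PyramidInfer Configuration" below) are taken
# from the recommended settings in the configs folder.

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))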

PyramidInfer Configuration

PyramidInfer has several hyperparameters that can be tuned to achieve better performance. The hyperparameters are defined in the configs folder, which contains the recommended settings for PyramidInfer.

Prefilling Stage

  • recent_ratio: The ratio of recent tokens that are not compressed and are used to find the pivotal contexts (PvCs).
  • prefill_decay_ratio: The decay ratio for gradually reducing the retained context length as the layers go deeper.
  • prefill_decay_strategy: The strategy used to decay the context length; it can be linear or cosine.
  • min_context_length: The minimum context length, which prevents the retained context from becoming too short.
  • layerwise_downsample_interval: The interval (in layers) at which the context length is downsampled. For larger models with more layers, we do not need to downsample the context length at every layer, which reduces the additional computation of finding PvCs.
  • distance_weight: Recent tokens closer to the latest token are given more weight when finding PvCs. (An illustrative sketch of the resulting layer-wise schedule follows this list.)
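
The sketch below shows one way such a layer-wise retention schedule could be computed from these hyperparameters; the function, its defaults, and the exact decay formulas are illustrative assumptions rather than the repository's code.

import math

def retained_context_length(ctx_len: int, layer_idx: int, num_layers: int,
                            prefill_decay_ratio: float = 0.95,
                            prefill_decay_strategy: str = "linear",
                            min_context_length: int = 32,
                            layerwise_downsample_interval: int = 2) -> int:
    """Illustrative only: how many prefill context tokens a layer might retain."""
    # Shrink the context only every `layerwise_downsample_interval` layers.
    step = layer_idx // layerwise_downsample_interval
    total_steps = max(1, (num_layers - 1) // layerwise_downsample_interval)

    if prefill_decay_strategy == "linear":
        # Reduce the kept fraction by a fixed amount at every downsample step.
        keep = max(0.0, 1.0 - (1.0 - prefill_decay_ratio) * step)
    else:  # "cosine": same endpoints, but decaying along a half cosine
        end = max(0.0, 1.0 - (1.0 - prefill_decay_ratio) * total_steps)
        keep = end + (1.0 - end) * 0.5 * (1.0 + math.cos(math.pi * min(step / total_steps, 1.0)))

    return max(min_context_length, int(ctx_len * keep))

With these assumed defaults, a 4096-token prompt at layer 16 of a 32-layer model would retain about 60% of the context, and the result is never allowed to fall below min_context_length.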

Generation Stage

  • gen_decay_ratio: The decay ratio for gradually reducing the retained context length as the layers go deeper in the generation stage. It is slightly different from the prefilling stage, as can be seen here.
  • gen_decay_strategy: The strategy used to decay the context length in the generation stage; it can be linear or cosine.
  • exceed_length_to_compress: The threshold for compressing additionally generated tokens. If the number of generated tokens exceeds this threshold, the additionally generated tokens are compressed. Note: in the generation stage, we do not compress the prompt KV from the prefilling stage; only the tokens generated afterwards are compressed.
  • gen_compress_ratio: If the number of additionally generated tokens exceeds the threshold above, they are compressed by this ratio. (A hedged sketch of this trigger follows this list.)
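
Below is a hedged sketch of this trigger: the prompt KV from the prefilling stage is left untouched, and only the newly generated part of the cache is compressed once it exceeds the threshold. The tensor layout, scoring, and exact eviction window are assumptions; the repository also appears to keep a recent window (gen_recent_length) of generated tokens uncompressed, which this sketch omits.

import torch

def maybe_compress_generated_kv(keys: torch.Tensor, values: torch.Tensor,
                                scores: torch.Tensor, prompt_len: int,
                                exceed_length_to_compress: int = 256,
                                gen_compress_ratio: float = 0.5):
    """Illustrative: compress only the generated portion of the KV cache.

    keys/values: (batch, heads, seq_len, head_dim); the first `prompt_len`
                 positions are the prefill prompt and are never touched here.
    scores:      (batch, seq_len) importance scores (e.g. recent attention mass).
    """
    gen_len = keys.shape[2] - prompt_len
    if gen_len <= exceed_length_to_compress:
        return keys, values  # threshold not reached yet

    # Keep the highest-scoring generated positions, preserving their order.
    keep_len = max(1, int(gen_len * gen_compress_ratio))
    gen_scores = scores[:, prompt_len:]
    top_idx = gen_scores.topk(keep_len, dim=-1).indices.sort(dim=-1).values + prompt_len

    idx = top_idx[:, None, :, None].expand(-1, keys.shape[1], -1, keys.shape[-1])
    kept_k = torch.cat([keys[:, :, :prompt_len], keys.gather(2, idx)], dim=2)
    kept_v = torch.cat([values[:, :, :prompt_len], values.gather(2, idx)], dim=2)
    return kept_k, kept_v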

Citation

@misc{yang2024pyramidinfer,
      title={PyramidInfer: Pyramid KV Cache Compression for High-throughput LLM Inference}, 
      author={Dongjie Yang and Xiaodong Han and Yan Gao and Yao Hu and Shilin Zhang and Hai Zhao},
      year={2024},
      eprint={2405.12532},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}


pyramidinfer's Issues

window_len == 1 in generation phase

Thanks for the great work!

Here, window_len == 1; is this expected?

Also, why evict tokens from -(1 + gen_recent_length + exceed_length_to_compress):-(1 + gen_recent_length) rather than -(gen_recent_length + exceed_length_to_compress):-(gen_recent_length)? What is the purpose of the 1 here?

LlamaForCausalLM is out of date

Great work!

Currently, I am reproducing this work. I found that the LlamaForCausalLM used in the repository is out of date, and its memory cost is much higher than the LlamaForCausalLM from Hugging Face.

Here are the results:

# original model in pyramid
Total Token Num: 14072
Max GPU Memory Per GPU (MB): 32477.294

# original model from hf
# transformers @ git+https://github.com/huggingface/transformers@2e48b3e8725326abd3e9cf82718f7d6debdd8297
Total Token Num: 14072
Max GPU Memory Per GPU (MB): 22554.607

Based on the information provided, it seems that using the LlamaForCausalLM from the Hugging Face Transformers library (at the specified commit) is more memory-efficient than the version used in the original repository. I'd suggest updating your code to use the Hugging Face version, as it appears to have a lower memory footprint.

Code for Fig 3 and Fig 5.

This is a great piece of work! I am wondering if you could provide the code for Figure 3 and Figure 5.
