
snapkv's People

Contributors

ctlllll, leeyeehoo, wendyh1108

snapkv's Issues

Question on GQA implementation

In GQA, only one copy of the KV cache is stored per group, but SnapKV saves a KV cache with num_key_value_heads * num_key_value_groups heads. In KV cache eviction the choice may indeed differ across query heads within the same group, but this increases the memory cost by a factor of num_key_value_groups. Is there a way we can solve this?
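One possible direction, sketched below purely for illustration (this is not SnapKV's actual implementation): aggregate the observation-window scores over the query heads that share a KV head before the top-k selection, so the compressed cache keeps only num_key_value_heads copies. The sketch assumes the group-consecutive head layout produced by transformers' repeat_kv; the function name, its signature, and the mean aggregation are all hypothetical choices.

    import torch

    def select_kv_per_group(attn_scores, num_key_value_groups, max_capacity):
        # attn_scores: [bsz, num_heads, window_size, kv_len], scores of the observation-window
        # queries over all keys, with num_heads = num_kv_heads * num_key_value_groups.
        bsz, num_heads, _, kv_len = attn_scores.shape
        num_kv_heads = num_heads // num_key_value_groups
        # Sum over the observation window, then average the query heads that share a KV head
        # (max-pooling over the group would be another option), so each KV head gets one
        # importance score per key position.
        votes = attn_scores.sum(dim=-2)                                      # [bsz, num_heads, kv_len]
        votes = votes.view(bsz, num_kv_heads, num_key_value_groups, kv_len).mean(dim=2)
        # One top-k per KV head: the kept cache then has num_kv_heads heads, not num_heads.
        return votes.topk(min(max_capacity, kv_len), dim=-1).indices         # [bsz, num_kv_heads, k]

The returned indices could then gather from key/value states of shape [bsz, num_kv_heads, kv_len, head_dim], avoiding the num_key_value_groups-fold duplication, at the price of forcing every query head in a group to keep the same positions.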

Grouped query attention implementation

Thank you for your nice work and for sharing the code. Grouped-query attention is used in the Mistral and Mixtral models. However, I found that the implementation in snapkv_utils.py is written for multi-head attention, so it may not be correct for grouped-query attention.
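As a toy illustration of the mismatch being pointed out (the head counts are Mistral-7B's public GQA configuration; the rest is example code, not the code from snapkv_utils.py):

    import torch

    # Mistral-7B-style GQA: 32 query heads share 8 KV heads, i.e. 4 query heads per group.
    bsz, q_len, head_dim = 1, 4096, 128
    num_heads, num_kv_heads = 32, 8
    num_key_value_groups = num_heads // num_kv_heads  # 4

    query_states = torch.randn(bsz, num_heads, q_len, head_dim)
    key_states = torch.randn(bsz, num_kv_heads, q_len, head_dim)   # the cache only has 8 heads

    # An MHA-style score computation only lines up after the KV heads are repeated to match
    # the query heads (what transformers' repeat_kv does), so selecting per *query* head
    # produces a compressed cache with num_heads rather than num_kv_heads heads.
    key_rep = key_states.repeat_interleave(num_key_value_groups, dim=1)      # [1, 32, 4096, 128]
    scores = query_states[..., -32:, :] @ key_rep.transpose(2, 3)            # [1, 32, 32, 4096]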

Questions on paper and code [prompting for Mistral, positional index, minor errors & questions in paper]

Hello :)
Thank you for the excellent work and for sharing your code. I've learned a lot and have a few questions about the paper and settings:

  • In Figures 2 and 3, what specifically do "prompt" and "context" represent? My guess is that "prompt" refers to the entire input sequence length, and "context" includes specific instructions. Should their labels be switched?

  • Could you share the specific prompt details applied in the Mistral experiment for measuring LongBench performance? Using the default LongBench settings, I observed lower performance overall, particularly in Qasper:

    • For Mistral-v2: Full: 28.92, SnapKV 2048: 26.43, 4096: 28.42 (reported: 33.06/32.47/33.36 respectively).
    • Intuitively, I think that moving the task-specific prompt from LongBench (e.g., "You are given a scientific article and a question. Answer the question....") to the end of the input sequence, so that it falls within the observation window, might improve performance (see the sketch after this list). Was there any such modification?
  • Following the SnapKV methodology, I expect the KV cache size to always be bounded by max_capacity_prompt. Yet why does an OOM error occur when exceeding a certain length (131K in Sec. 5.1)? Could it be due to recalculating the attention weights in Line 9 of Listing 1?
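To make the re-ordering idea from the third bullet concrete, here is a hypothetical sketch (the instruction text, variable names, and template are illustrative; they are not the paper's or LongBench's exact format):

    # Default LongBench-style ordering: the task instruction comes before the long context,
    # so under SnapKV it sits far outside the observation window at the end of the prompt.
    task_instruction = (
        "You are given a scientific article and a question. "
        "Answer the question as concisely as you can."
    )
    context = "..."   # the long document (placeholder)
    question = "..."  # the task question (placeholder)

    default_prompt = f"{task_instruction}\n\n{context}\n\nQuestion: {question}"

    # Re-ordered: place the instruction after the context, so it lands inside the last
    # `window_size` tokens that SnapKV uses to score and select the KV cache.
    reordered_prompt = f"{context}\n\n{task_instruction}\n\nQuestion: {question}"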

Additionally, there seems to be a minor error in Figure 7 where both the top and bottom plots are labeled as "without Pooling." It might be less confusing to label the bottom plot as "with Pooling."

Thank you for any insights you can provide. I really appreciate the motivation and methodology behind your work!

Observation window size and consistency between layers

Hello :)

Thank you for the brilliant work and for sharing your code. After reading the paper and reviewing the related code, I have the following questions:

  1. Have you conducted experiments related to the observation window size (e.g., sizes ranging from 1 to 64)? How does this impact the hit rates and overall model performance?
  2. In the "layer-wise average hit rate" experiment, the hit rate of the middle layers is significantly lower than that of the shallow and deep layers. Do you know the reason for this?

Thank you for your excellent paper!

Question on H2O experiment reproduction

Thanks for your excellent work!

As stated in the paper Table 1: "Performance comparison of SnapKV and H2O across various LLMs on LongBench", could you provide the scripts/codes for reproducing H2O evaluations on LongBench?

Thanks in advance.

What prompt was used in Needle in a Haystack test?

I tried to reproduce the needle test with LWM-Text-Chat-1M, but the model simply refuses to answer. I have tried the following prompts in the needle test, and the model just generates </s>:

<s>[INST] <<SYS>>
You are a helpful AI bot that answers questions for a user. Keep your response short and direct
<</SYS>>
{ context }

{retrieval_question} Don't give information outside the document or repeat your findings
[/INST]

and

<s>[INST] <<SYS>>
You are a helpful AI bot that answers questions for a user. Keep your response short and direct
<</SYS>>
{ context }

{retrieval_question} Don't give information outside the document or repeat your findings
[/INST]</s>

Can't run LongBench!

Here is my environment. The version of transformers meets the requirements in monkeypatch.py:

torch==2.2.0
transformers==4.37.0

The traceback is as follows:


>> python pred_snap.py --model llama2-7b-chat-4k --compress_args_path ablation_c1024_w32_k7_maxpool.json

Traceback (most recent call last):
File "experiments/LongBench/pred_snap.py", line 321, in
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "experiments/LongBench/pred_snap.py", line 132, in get_pred_single_gpu
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/generation/utils.py", line 1474, in generate
return self.greedy_search(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/generation/utils.py", line 2335, in greedy_search
outputs = self(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
outputs = self.model(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1035, in forward
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 398, in _prepare_4d_causal_attention_mask_for_sdpa
expanded_4d_mask = attn_mask_converter.to_4d(
File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 137, in to_4d
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
RuntimeError: The size of tensor a (3509) must match the size of tensor b (7017) at non-singleton dimension 3

I think the reason is that DynamicCache.get_usable_length conflicts with the causal-mask construction in _prepare_4d_causal_attention_mask_for_sdpa.

I would like to know how I can quickly fix this. Thanks :)
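To make the suspected conflict concrete, here is a toy sketch of the bookkeeping involved (the numbers are made up and this only illustrates the hypothesis above, not a verified diagnosis): the 4D causal mask is sized from past_key_values.get_usable_length(), which reflects the compressed cache, while the 2D attention mask that generate keeps extending covers every token actually seen.

    # Toy numbers (hypothetical): a long prompt compressed by SnapKV down to 1024 KV entries.
    tokens_seen = 6000            # width of the running 2D attention mask after prefill
    compressed_cache_len = 1024   # what DynamicCache.get_usable_length() reports
    q_len = 1                     # one new token per greedy-decoding step

    key_value_length = compressed_cache_len + q_len   # used to size the 4D causal mask
    attn_mask_length = tokens_seen + q_len            # width of the expanded padding mask

    # to_4d() then calls masked_fill between a [..., q_len, key_value_length] causal mask and a
    # [..., q_len, attn_mask_length] padding mask, which fails to broadcast at dimension 3;
    # that is the same kind of "size of tensor a must match size of tensor b" RuntimeError as above.
    print(key_value_length, attn_mask_length)         # 1025 vs 6001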

It seems that SnapKV needs to be able to do "prefill" at least once before the prompt can be compressed.

SnapKV needs a full-length query-key matmul before its first self-attention, which has $O(n^2)$ space complexity. So does SnapKV need to be able to do "prefill" at least once before the prompt can be compressed?

After that, it can reduce the memory footprint during the decoding phase.

    def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
        # Check if this is the prefill phase: queries and keys cover the same (full) prompt length.
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len < self.max_capacity_prompt:
            return key_states, value_states
        else:
            # Scores of the last `window_size` queries against all `q_len` keys:
            # a [bsz, num_heads, window_size, q_len] tensor, i.e. O(window_size * q_len) per head,
            # not the full O(q_len^2) attention map. (Requires `import math` and `import torch`.)
            attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)
            # ... (remainder of update_kv omitted in this excerpt)
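For scale, here is a rough back-of-the-envelope comparison of the two score tensors involved (all numbers are illustrative assumptions: 32 attention heads, fp16 scores, window_size = 32, a 131,072-token prompt, batch size 1, a single layer):

    # Size of the attention-score tensors at n = 131,072 (illustrative assumptions as above).
    n, heads, window, bytes_per_el = 131_072, 32, 32, 2

    # What the quoted update_kv materializes: [1, heads, window, n].
    window_scores_bytes = heads * window * n * bytes_per_el
    print(f"observation-window scores: {window_scores_bytes / 2**20:.0f} MiB")  # 256 MiB

    # A naively materialized full prefill attention map: [1, heads, n, n].
    full_scores_bytes = heads * n * n * bytes_per_el
    print(f"full n x n attention map: {full_scores_bytes / 2**40:.1f} TiB")     # 1.0 TiB

So the scoring step quoted above is linear in the prompt length; whether the prefill pass itself is quadratic in memory depends on the attention kernel used for the forward pass (for example, whether it materializes the full map or uses a FlashAttention/SDPA-style kernel), which is a separate question from the compression step.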
