snapkv's Issues
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) RuntimeError: The size of tensor a (3509) must match the size of tensor b (7017) at non-singleton dimension 3
I am using transformers==4.27.0; why am I getting this error?

```
RuntimeError: The size of tensor a (3509) must match the size of tensor b (7017) at non-singleton dimension 3
```
Question on GQA implementation
In GQA, only one copy of the KV cache should be saved per group, but SnapKV saves the cache with num_key_value_heads * num_key_value_groups heads. Admittedly, during KV cache eviction the selection may differ across heads within the same group, but this increases memory cost by a factor of num_key_value_groups. Is there a way to solve this?
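One possible direction (a sketch under assumptions, not the repo's implementation): aggregate the observation-window attention scores across the query heads of each KV group before the top-k selection, so a single index set, and hence a single cached copy, is kept per KV head:

```python
import torch

def select_indices_per_group(attn_weights, num_key_value_groups, max_capacity):
    # attn_weights: [bsz, num_heads, window, kv_len] -- observation-window
    # scores, one row per query head (after repeat_kv-style expansion).
    bsz, num_heads, window, kv_len = attn_weights.shape
    num_kv_heads = num_heads // num_key_value_groups
    # Sum over the window, then average over the query heads of each group,
    # yielding one importance score per KV head instead of per query head.
    scores = attn_weights.sum(dim=-2)  # [bsz, num_heads, kv_len]
    scores = scores.view(bsz, num_kv_heads, num_key_value_groups, kv_len).mean(dim=2)
    # These top-k indices can be applied to the un-repeated KV cache directly,
    # avoiding the num_key_value_groups-fold duplication.
    return scores.topk(max_capacity, dim=-1).indices  # [bsz, num_kv_heads, max_capacity]
```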
Can I use SnapKV without flash-attention?
Thank you for making the SnapKV code public.
I would like to ask whether SnapKV can be used without flash-attention.
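Not an authoritative answer, but for context: recent transformers releases let you pick the attention backend at load time; whether SnapKV's monkey patch also hooks the non-flash forward passes is something to verify against snapkv_utils.py. A minimal sketch (the model id is chosen only for illustration):

```python
import torch
from transformers import AutoModelForCausalLM

# attn_implementation is a standard transformers option; "eager" and "sdpa"
# both avoid the flash-attention kernel. Whether the SnapKV patch covers
# these code paths is an assumption to confirm with the authors.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
```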
Could you provide the code for visualizing the hit rate?
Could you provide the code for visualizing the hit rate, as in Figures 2 and 3?
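In the meantime, a rough sketch of how such a plot could be produced (the hit-rate definition here is an assumption, not necessarily the paper's exact metric; the per-layer index lists are random placeholders standing in for indices collected with forward hooks):

```python
import torch
import matplotlib.pyplot as plt

def hit_rate(selected_idx, important_idx):
    # Fraction of the positions that actually mattered during generation
    # which SnapKV's selection retained.
    sel, imp = set(selected_idx.tolist()), set(important_idx.tolist())
    return len(sel & imp) / max(len(imp), 1)

# Placeholder data so the sketch runs; real indices would come from hooks.
num_layers, kv_len = 32, 4096
selected_per_layer = [torch.randperm(kv_len)[:1024] for _ in range(num_layers)]
important_per_layer = [torch.randperm(kv_len)[:1024] for _ in range(num_layers)]

rates = [hit_rate(s, i) for s, i in zip(selected_per_layer, important_per_layer)]
plt.plot(rates, marker="o")
plt.xlabel("layer")
plt.ylabel("hit rate")
plt.show()
```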
Can SnapKV compress the KV cache when different user questions are posed about the same context?
Say there is a long document, and two users ask two different questions based on it. The questions are not at all similar and target different parts of the document. In this case, can SnapKV compress the context robustly?
Grouped query attention implementation
Thank you for your nice work and for sharing the code. Grouped-query attention is used in the Mistral and Mixtral models. However, I found that the implementation in snapkv_utils.py is for multi-head attention; it may not be correct for grouped-query attention.
maybe a bug in `update_kv` function
Line 50 in ea655b1:
In the `update_kv` function, the `attention_mask` argument is never actually used; instead, the variable is overridden (shadowed) by a locally built mask.
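A minimal illustration of the shadowing pattern (not the exact repo code):

```python
import torch

def update_kv_sketch(attention_mask, window_size):
    # Bug pattern: the `attention_mask` argument is immediately shadowed by a
    # locally built causal window mask, so the caller's mask is silently ignored.
    attention_mask = torch.full((window_size, window_size),
                                torch.finfo(torch.float32).min)
    attention_mask = torch.triu(attention_mask, diagonal=1)
    return attention_mask

# A fix would be to give the local mask its own name (e.g. `window_mask`) and,
# if a caller-supplied `attention_mask` exists, combine the two.
```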
Why is compression only done for decoding?
Questions on paper and code [prompting for mistral, positional index, minor errors & questions in paper]
Hello :)
Thank you for the excellent work and for sharing your code. I've learned a lot and have a few questions about the paper and settings:
- In Figures 2 and 3, what specifically do "prompt" and "context" represent? My guess is that "prompt" refers to the entire input sequence length, and "context" includes specific instructions. Should their labels be switched?
- Could you share the specific prompt used in the Mistral experiment for measuring LongBench performance? Using the default LongBench settings, I observed lower performance overall, particularly on Qasper:
- For Mistral-v2: Full: 28.92, SnapKV 2048: 26.43, 4096: 28.42 (reported: 33.06/32.47/33.36 respectively).
- Intuitively, I think that moving the task-specific prompt (e.g., "You are given a scientific article and a question. Answer the question....") from LongBench to the end of the input sequence, so that it falls within the window range, might improve performance. Was there any such modification?
- Following the SnapKV methodology, I expect the KV cache size to always be bounded by `max_capacity_prompt`. Yet why does an OOM error occur when exceeding a certain length (131K, in Sec. 5.1)? Could it be due to recalculating the attention weights in Line 9 of Listing 1?
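For rough intuition (back-of-the-envelope arithmetic under assumed Mistral-7B-like shapes, not a profile of the actual run), the window-query attention scores still scale linearly with the full prompt length:

```python
# attn_weights = Q[..., -window:, :] @ K^T has shape [bsz, heads, window, seq_len]
bsz, heads, window, seq_len = 1, 32, 32, 131_072
bytes_per_elem = 2  # fp16
per_layer = bsz * heads * window * seq_len * bytes_per_elem
print(per_layer / 2**20, "MiB")  # 256.0 MiB per layer, transient but allocated at once
```

Softmax and pooling allocate similar-sized temporaries, so the peak memory during this recalculation can grow well beyond the bounded cache itself.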
Additionally, there seems to be a minor error in Figure 7 where both the top and bottom plots are labeled as "without Pooling." It might be less confusing to label the bottom plot as "with Pooling."
Thank you for any insights you can provide. I really appreciate the motivation and methodology behind your work!
observation window size and consistency between layers
Hello :)
Thank you for the brilliant work and for sharing your code. After reading the paper and reviewing the related code, I have the following questions:
- Have you conducted experiments related to the observation window size (e.g., sizes ranging from 1 to 64)? How does this impact the hit rates and overall model performance?
- In the "layer-wise average hit rate" experiment, the hit rate of the middle layers is significantly lower than that of the shallow and deep layers. Do you know the reason for this?
Thank you for your excellent paper!
Question on H2O experiment reproduction
Thanks for your excellent work!
As stated for Table 1 of the paper ("Performance comparison of SnapKV and H2O across various LLMs on LongBench"), could you provide the scripts/code for reproducing the H2O evaluations on LongBench?
Thanks in advance.
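Not the authors' scripts, but for anyone else attempting a reproduction, the H2O selection rule is simple to sketch from its paper's description (accumulated attention scores pick "heavy hitters", plus an always-kept recent window; the budget handling here is an assumption):

```python
import torch

def h2o_keep_indices(attn_weights, heavy_budget, recent_budget):
    # attn_weights: [bsz, heads, q_len, kv_len]
    scores = attn_weights.sum(dim=-2)  # attention received per position, accumulated over queries
    kv_len = scores.shape[-1]
    # Heavy hitters are chosen outside the always-kept recent window.
    heavy = scores[..., : kv_len - recent_budget].topk(heavy_budget, dim=-1).indices
    recent = torch.arange(kv_len - recent_budget, kv_len, device=scores.device)
    recent = recent.expand(*heavy.shape[:-1], recent_budget)
    return torch.cat([heavy, recent], dim=-1)  # positions to keep in the cache
```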
What prompt was used in the Needle-in-a-Haystack test?
I tried to reproduce the needle test with LWM-Text-Chat-1M, but the model just refuses to answer. I have tried the following prompts, and the model only generates `</s>`:
```
<s>[INST] <<SYS>>
You are a helpful AI bot that answers questions for a user. Keep your response short and direct
<</SYS>>
{ context }
{retrieval_question} Don't give information outside the document or repeat your findings
[/INST]
```
and
```
<s>[INST] <<SYS>>
You are a helpful AI bot that answers questions for a user. Keep your response short and direct
<</SYS>>
{ context }
{retrieval_question} Don't give information outside the document or repeat your findings
[/INST]</s>
```
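One sanity check that may help (the model id and the availability of a chat template are assumptions; LWM may ship its own template that differs from the hand-written [INST] block above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("LargeWorldModel/LWM-Text-Chat-1M")
context = "..."             # the haystack document (placeholder)
retrieval_question = "..."  # the needle question (placeholder)
messages = [{"role": "user", "content": f"{context}\n{retrieval_question}"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False,
                                       add_generation_prompt=True)
print(prompt)  # compare against the hand-written prompt formats
```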
Can't run LongBench!
Here is my environment; the transformers version meets the requirement in monkeypatch.py:

```
torch==2.2.0
transformers==4.37.0
```
The traceback is as follows:
```
>> python pred_snap.py --model llama2-7b-chat-4k --compress_args_path ablation_c1024_w32_k7_maxpool.json
Traceback (most recent call last):
  File "experiments/LongBench/pred_snap.py", line 321, in <module>
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "experiments/LongBench/pred_snap.py", line 132, in get_pred_single_gpu
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/generation/utils.py", line 1474, in generate
    return self.greedy_search(
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/generation/utils.py", line 2335, in greedy_search
    outputs = self(
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1035, in forward
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 398, in _prepare_4d_causal_attention_mask_for_sdpa
    expanded_4d_mask = attn_mask_converter.to_4d(
  File "/data1/ss/anaconda3/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 137, in to_4d
    expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
RuntimeError: The size of tensor a (3509) must match the size of tensor b (7017) at non-singleton dimension 3
```
I think the reason is that DynamicCache.get_usable_length conflicts with the causal-mask helper _prepare_4d_causal_attention_mask_for_sdpa.
I would like to know how I can quickly fix this. Thanks :)
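One untested workaround (an assumption, since the monkey patch appears to target the flash-attention forward): load the model so the SDPA mask-preparation path is never taken:

```python
import torch
from transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder: whatever checkpoint pred_snap.py loads
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # skips _prepare_4d_causal_attention_mask_for_sdpa
)
```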
It seems that SnapKV needs to do a "prefill" at least once before the prompt can be compressed.
SnapKV needs a full-length query-key matmul before its first self-attention, which is a full prefill pass; after that, it can save memory footprint during the decoding phase.
```python
def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
    # check if prefix (prefill) phase: queries and keys cover the same positions
    assert key_states.shape[-2] == query_states.shape[-2]
    bsz, num_heads, q_len, head_dim = query_states.shape
    if q_len < self.max_capacity_prompt:
        # prompt already fits within the budget; keep the cache untouched
        return key_states, value_states
    else:
        # full-length matmul: observation-window queries against ALL keys
        attn_weights = torch.matmul(query_states[..., -self.window_size:, :],
                                    key_states.transpose(2, 3)) / math.sqrt(head_dim)
        # ... (selection and compression continue in the repo)
```
Could the effect of Clustering via Pooling be even greater?
Just a guess: what would happen if H2O also used Clustering via Pooling in the comparison? It seems that Clustering via Pooling could improve the effectiveness of such token-dropping methods in general.
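For concreteness, the pooling step could in principle be bolted onto H2O's accumulated scores as well; a minimal sketch (the score source and budget handling are assumptions):

```python
import torch.nn.functional as F

def pooled_topk(scores, capacity, kernel_size=5):
    # scores: [bsz, heads, kv_len] per-position importance (e.g. H2O's
    # accumulated attention). Max-pooling keeps the neighbors of a
    # high-scoring token together ("clustering"); an odd kernel_size
    # preserves the sequence length.
    pooled = F.max_pool1d(scores, kernel_size=kernel_size, stride=1,
                          padding=kernel_size // 2)
    return pooled.topk(capacity, dim=-1).indices
```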
Only the KV cache is compressed. Are the sizes of Q and K inconsistent when attention is calculated?
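For what it's worth, mismatched query/key lengths are not a problem in attention: during decoding the query covers only the new token(s), while the compressed cache supplies the keys. A minimal shape check (sizes are illustrative):

```python
import torch

bsz, heads, head_dim, compressed_len = 1, 32, 128, 1024
q = torch.randn(bsz, heads, 1, head_dim)               # one new token at decode time
k = torch.randn(bsz, heads, compressed_len, head_dim)  # compressed KV cache
scores = q @ k.transpose(-1, -2)
print(scores.shape)  # torch.Size([1, 32, 1, 1024]) -- no dimension mismatch
```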