gemma2: add sliding window mask #8227
Conversation
Thanks for your work. I tested your PR by regenerating the gguf from hf with: The model is still unable to solve questions that are easy for aistudio gemma2. It could be that there is something missing in your implementation, or there are other issues besides SWA. Example problem (answer is 7 or 8):
I ran the inference without offloading the entire model to the GPU since I don't have enough VRAM.
@matteoserva I think it's normal for a small model like this one to make a math mistake. What this PR is trying to address is that gemma 2 currently breaks after generating more than 4096 tokens. We could try, for example, inputting a long document (like Shakespeare) and then asking it something related.
Sorry, I was in a hurry and I didn't explain why I made that post. With this PR (and also without it) the model breaks on even simple questions, well before the 4096-token limit. It could be related to how SWA was implemented, but I'm not sure.
@matteoserva I think the bug that you described is unrelated to this PR. The goal here is to make no change if you're generating fewer than 4096 tokens. Probably you should open an issue so other users can share their results (e.g. with different quantizations, sampling settings, etc.)
Can we have one or two test cases (prompt + expected outcome) that work in aistudio and should work with llama.cpp and this PR?
@bfroemel I heard other users report that after 4096 tokens the generation breaks completely (gibberish output), so probably you just need to input 4096 tokens (or more, it doesn't need to be exact), then see if it still speaks English or is drunk. (If someone knows this better, feel free to correct what I said.)
I tested this PR using gemma-9b unquantized. Without this PR:
With this PR:
@ngxson Attached is a test prompt which should be about 6k tokens. I tried it on aistudio (I only have the 27b-it model available), and I get this output:
Of course, I regenerated the output from both aistudio and llama.cpp a couple of times: aistudio always tried to answer the question in the prompt,
Perplexity with 8192 context improves a lot.
Perfect, thanks @slaren @bfroemel. To correct what I said earlier: without SWA, the model does not output gibberish, but repeated output (ref: #8197 (comment)). That explains what @bfroemel got from the master branch. However, even with this PR, it seems like we still have issues with generation quality in general. The test with video transcription seems to be a good idea (better than Shakespeare), so let's keep testing with that.
Uhm, just to correct my report: now I see the same repeated text on the master branch (the thing I saw earlier was polluted by ollama; on pure llama.cpp master I see the repeating mess).
-> OK, also focusing on the video transcript test from now on.
Co-authored-by: Arlo Phoenix <arlo-phoenix@users.noreply.github.com>
src/llama.cpp
Outdated
```cpp
if (lctx.model.arch == LLM_ARCH_GEMMA2) {
    GGML_ASSERT(lctx.inp_KQ_mask_SWA);
    GGML_ASSERT(hparams.n_sliding > 0);
    data = (float *) lctx.inp_KQ_mask->data;
    data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
    // because layer masks alternate for gemma 2, we only need to take the first 2 layers
}
```
This can be simplified a bit.
```diff
-if (lctx.model.arch == LLM_ARCH_GEMMA2) {
-    GGML_ASSERT(lctx.inp_KQ_mask_SWA);
-    GGML_ASSERT(hparams.n_sliding > 0);
-    data = (float *) lctx.inp_KQ_mask->data;
-    data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
-    // because layer masks alternate for gemma 2, we only need to take the first 2 layers
-}
+if (lctx.inp_KQ_mask_SWA) {
+    data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
+}
```
If I am not mistaken, mistral uses SWA every layer. So maybe this needs to be separated to allow having only inp_KQ_mask_SWA? Will the same implementation work?
I've just looked at the mistral reference implementation; they seem to use a different mask for each layer. Link: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/cache.py
So I think my previous version (using std::vector) can handle that. Do you think I should revert the change?
It surprises me a bit, since mistral's quality doesn't seem to degrade even though it's missing SWA (or does it only break after 4096 tokens?)
I have been looking at this code for a while and reviewing the mistral paper, and I think this is an implementation of the rolling buffer cache rather than sliding window attention. As far as I can tell, mistral has the same sliding window of 4096 tokens on each layer. Knowing that, it is possible to reduce the size of the KV cache to the sliding window size, but that requires some additional housekeeping so that e.g. the rope still receives the absolute positions of the tokens, while the data is actually stored at position pos % sliding_window. But maybe I am misunderstanding something, can you point me to the specific code?
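To make the rolling-buffer idea above concrete, here is a minimal, hypothetical C++ sketch (illustrative names and layout, not llama.cpp's actual KV cache code): keys are stored at slot pos % window, while the absolute position is kept alongside so that RoPE and the attention mask can still see it.

```cpp
// Hypothetical sketch of a rolling-buffer KV cache, not llama.cpp code.
#include <algorithm>
#include <cstdint>
#include <vector>

struct rolling_kv_cache {
    uint32_t             window;   // sliding window size, e.g. 4096
    uint32_t             head_dim; // per-token key size (values omitted for brevity)
    std::vector<float>   k;        // window * head_dim floats
    std::vector<int32_t> pos;      // absolute position stored in each slot, -1 if empty

    rolling_kv_cache(uint32_t window, uint32_t head_dim)
        : window(window), head_dim(head_dim), k(window * head_dim), pos(window, -1) {}

    // store the key of the token at absolute position p, overwriting the slot
    // that previously held position p - window
    void put(int32_t p, const float * key) {
        const uint32_t slot = (uint32_t) p % window;
        std::copy(key, key + head_dim, k.begin() + slot * head_dim);
        pos[slot] = p;
    }

    // a cached slot is visible to the query at absolute position p only if it
    // is filled and lies within the sliding window
    bool visible(uint32_t slot, int32_t p) const {
        return pos[slot] >= 0 && p - pos[slot] < (int32_t) window;
    }
};
```

The extra housekeeping mentioned above is essentially the pos[] bookkeeping: the graph still works with absolute positions, only the storage index wraps.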
Yes, it should be possible. The thing I cannot figure out is how to avoid calling llama_kv_cache_find_slot() per-layer - seems it would be a big waste to do it like this, although it would generalize to support arbitrary KV cache layer sizes.
Yeah, I assume the code is a reference implementation, so not very good quality. Having a rolling buffer would be ideal for llama.cpp, but it seems like too many changes. This is mostly to answer your question earlier: will the same implementation work? Yes, it works with a different sliding window mask per layer, but it will be a waste of memory without a rolling buffer.
How would the mask differ in each layer? My understanding is that the mask would be the same for all the layers, and it relies on the fact that the states in the KV cache depend on all the previous tokens to be able to access information beyond the sliding window.
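A small, hypothetical illustration of that point (not code from the PR): even with the same window W on every layer, information can still propagate roughly layer * W tokens back, because each layer attends to hidden states that already summarize the previous window of the layer below.

```cpp
// Hypothetical illustration: theoretical lower bound of the receptive field
// at a given layer when every layer uses the same sliding window.
#include <algorithm>
#include <cstdio>

static int oldest_reachable_pos(int i, int layer, int window) {
    // position i at this layer can be influenced by input tokens as far back
    // as i - layer * window (clamped at 0)
    return std::max(0, i - layer * window);
}

int main() {
    const int window   = 4096;
    const int layers[] = {1, 2, 4, 8};
    for (int layer : layers) {
        printf("layer %d: token 8191 can be influenced back to position %d\n",
               layer, oldest_reachable_pos(8191, layer, window));
    }
    return 0;
}
```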
I looked deeper into the paper; it seems like I missed something.
Looking at this figure:
And the explanation:
I'd assume that the mask for each layer is shifted by the size of window - 1, for example:
- layer 0:
0, 0, 0, 1, 1
- layer 1:
0, 0, 1, 1, 0
- layer 2:
0, 1, 1, 0, 0
- ...
But then what I don't understand is the phrase "position i of the layer k, hi, attends to all hidden states from the previous layer with positions between i − W and i". On the surface, it seems to explain how layer 1 knows about the tokens that fall outside of its window (which are in layer 0), but then what's not clear to me is how one layer can attend to the previous one.
Also, looking at the HF implementation code, it seems like there is no such thing. They just add the same attention mask for all layers: https://github.com/huggingface/transformers/blob/e65502951593a76844e872fee9c56b805598538a/src/transformers/models/mistral/modeling_mistral.py#L354
This can be simplified a bit.
Changed in ed5496f
I think for now we can keep the implementation this way; I'll need more time to figure out how mistral actually uses SWA.
But then what I don't understand is the phrase "position i of the layer k, hi, attends to all hidden states from the previous layer with positions between i − W and i". On the surface, it seems to explain how layer 1 knows about the tokens that fall outside of its window (which are in layer 0), but then what's not clear to me is how one layer can attend to the previous one.
I think it doesn't directly "attend" to the tokens from the previous layer. It just receives information about those tokens through the output of the previous layer.
I have also been trying to understand this concept for the past 3 days. I did not pay attention to this when Mistral v1 was released, and I remember seeing that Mistral v2 removed SWA.
Do quants need to be redone, or is this just for the inference side?
@Dampfinchen it's recommended to re-generate, but not required. We have a default value for the added metadata, so at least existing ggufs won't break.
The only benefit presumably being that long-context imatrix measurements are more accurate?
src/llama.cpp
Outdated
```diff
@@ -2099,6 +2101,7 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float expert_weights_scale = 0.0;
+    uint32_t n_sliding = 0; // sliding window attention (SWA)
```
```diff
-    uint32_t n_sliding = 0; // sliding window attention (SWA)
+    uint32_t n_swa = 0; // sliding window attention (SWA)
```
Changed in ed5496f
src/llama.cpp
Outdated
```diff
@@ -2661,6 +2664,9 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
     struct ggml_tensor * inp_s_seq;  // I32 [n_kv, n_batch]
+
+    // KQ mask per layer, used by sliding window attention (gemma 2)
+    struct ggml_tensor * inp_KQ_mask_SWA;
```
```diff
-    struct ggml_tensor * inp_KQ_mask_SWA;
+    struct ggml_tensor * inp_KQ_mask_swa;
```
Changed in ed5496f
src/llama.cpp
Outdated
```cpp
float * data     = (float *) lctx.inp_KQ_mask->data;
float * data_swa = nullptr;
const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;
```
I don't understand the meaning of n_keep_swa. Seems this won't work with batches of multiple sequences.
Yeah, I'm not sure if I'm doing it correctly: it is to emulate the rolling. If we input n_tokens, then we only keep n_sliding - n_tokens tokens in the cache, so the total number of tokens used for attention is n_tokens plus n_sliding - n_tokens, which equals n_sliding.
Seems to me just restricting the position delta to be less than n_swa is enough:
```diff
diff --git a/src/llama.cpp b/src/llama.cpp
index 71b7ef62..fa207234 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -12722,7 +12722,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                     // may need to cut off old tokens for sliding window
                     if (data_swa) {
-                        if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
+                        if (pos - lctx.kv_self.cells[i].pos >= hparams.n_sliding) {
                             f = -INFINITY;
                         }
                         data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
```
This way, in SWA layers, the token with position 4096 does not "see" the token with position 0, but does "see" the token at position 1.
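A simplified sketch of what that condition does for a single query row (hypothetical names, not the exact llama.cpp code), just to make the example above concrete:

```cpp
// Hypothetical, simplified mask fill for one query token at absolute position
// `pos`: causal masking plus the sliding-window cutoff from the diff above
// (position delta >= n_swa => masked out).
#include <cmath>
#include <cstddef>
#include <vector>

void fill_swa_mask_row(std::vector<float> & mask_swa,      // one value per KV cell
                       const std::vector<int> & cell_pos,  // absolute position of each KV cell
                       int pos, int n_swa) {
    for (size_t i = 0; i < cell_pos.size(); ++i) {
        float f = (cell_pos[i] > pos) ? -INFINITY : 0.0f;   // causal part
        if (pos - cell_pos[i] >= n_swa) {
            f = -INFINITY;                                   // older than the sliding window
        }
        mask_swa[i] = f;
    }
}
// e.g. with n_swa = 4096 and pos = 4096: the cell at position 0 is masked
// (4096 - 0 >= 4096), while the cell at position 1 is still visible.
```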
OK, thanks, that's clear to me now. I changed this code in ed5496f
I think the purpose of this PR is that right now the context size is fixed at 4K, and this enables sliding window attention to get accurate results at 8K, so it's very important.
Perplexity improved a bit with the latest change.
Looking really good, but I'm still seeing seemingly degraded performance/quality compared to the aistudio Gemma 2 model output :/ I am able to test the 27b-it fp16 model locally (same temperature and top-p). Maybe just expected degradation, because originally the model was bf16? Here is the same perplexity test for the 27b-it fp16:
@bfroemel Degraded quality is not expected - show us the exact commands that you are using; otherwise we mainly ignore such comments, because there are many ways to use the examples incorrectly and in the majority of cases it is a user error.
Long-term we should refactor the KV cache code to support SWA properly and with less memory. For now we can merge this so that we have Gemma2 support.
Let's merge when CI passes.
@ggerganov At first I thought it was something related to longer context and maybe a bug in the SWA implementation, but looking back at @matteoserva's test, it is as simple as this:
Locally with llama.cpp + the applied PR, I get the confused answer of 18 apples, while the model on aistudio correctly answers 8 apples (also set to a temperature of 0). Gemma 2 goes through these reasoning problems step by step, like @matteoserva already showed, and along the way on llama.cpp it probably confused two objects (fruits and apples) and ended up with the wrong result. -> Probably best to open a new issue.
@bfroemel Have you tried it in bf16 instead of fp16?
Ah of course, I can try this out without offloading. /edit: grr, now I am confusing stuff. The test is still ongoing. /edit2: same bad result (18 apples). So it's not bf16.
@bfroemel @qnixsynapse @matteoserva I moved the discussion related to generation quality to #8240, could you copy-paste your results there? (And also move the discussion there.) Thank you.
@bfroemel You have an extra BOS token in your command. No need to add the token explicitly because it is automatically added. Use
(@ggerganov I am feeling a bit dumb now :) Thanks for this hint! Indeed, the extra BOS token degrades the model performance significantly further. With a correct prompt, at least I am getting a good apple count for that particular prompt.)
This is a cherry-pick of ggerganov/llama.cpp#8227
* gemma2: add sliding window mask
* fix data_swa uninitialized
* better naming
* add co-author

Co-authored-by: Arlo Phoenix <arlo-phoenix@users.noreply.github.com>

* replace list with single tensor
* update
* llama : minor styling
* convert : add sanity check for query_pre_attn_scalar
* fix small typo in README

---------

Co-authored-by: Arlo Phoenix <arlo-phoenix@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This is a hack to support sliding window attention for gemma 2 by masking past tokens.
The goal is to make it work. While the ideal solution is to have per-layer KV cache management (with a different n_kv per layer), this seems to be quite challenging (ref: #3377 (comment)). This implementation is mainly inspired by @arlo-phoenix 's work: arlo-phoenix@265a8f2
(Test & perplexity below in the comment)
Link to working gguf: https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/tree/main
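As a rough sketch of the approach described above (illustrative names only, and the even/odd assignment is an assumption rather than a statement about llama.cpp's actual layer numbering): Gemma 2 alternates sliding-window and full-attention layers, so the graph only needs to pick between two precomputed KQ masks per layer instead of managing a separate KV cache size per layer.

```cpp
// Illustrative sketch, not the actual llama.cpp graph-building code: two masks
// are built once per batch, and each layer selects one of them.
struct ggml_tensor;  // opaque stand-in for the real ggml type

struct kq_masks {
    ggml_tensor * full; // regular causal mask: all past tokens visible
    ggml_tensor * swa;  // causal mask with tokens older than n_swa masked out
};

// assumption: even layers use sliding window attention, odd layers use full
// attention; the important part is only that the choice alternates per layer
static ggml_tensor * mask_for_layer(const kq_masks & m, int il) {
    return (il % 2 == 0) ? m.swa : m.full;
}
```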