llama : support sliding window attention #3377
Comments
Trying to figure this out. I found this reference to be helpful, as it seems to put all the necessary code in one place. Probably I can manage it in a week or two, but I'm not sure about it.
Any news on that one? 🥹
I think KoboldCPP introduced something similar with their recent rewrite of the "smartcontext" feature.
Nope, they added cache shifting; sliding window attention is a different attention mechanism, and the two do very different things.
Maybe it's not appropriate to insist on it (I apologize if this bothers you), but this feature seems to be the one with the most thumbs up on the dashboard. 😢
This issue was closed because it has been inactive for 14 days since being marked as stale.
It feels like, since Mistral 7B last year, there hasn't been much interest in this technique. Even later Mistral models dropped it as a feature. Taking this into account, I guess we can leave this issue closed.
@ggerganov As the new gemma 2 models use SWA (in addition to GQA, in some sort of alternating scheme?), I suggest this be revisited. As-is, gemma 2 pretty much falls apart after 4k context or so using llama.cpp.
Just for the record: I looked at this again in Jan or Feb, but didn't manage to get it working. My understanding isn't good enough, so I don't expect to finish it soon. Still, it really seems that only a few additional lines of code are needed.
I think it may be worth re-evaluating this, since there's an increasing number of models supporting SWA (mistral, phi, gemma2 off the top of my head).
I roughly implemented sliding window attention here: https://github.com/arlo-phoenix/llama.cpp/tree/gemma2. No idea if it's correct, and the output isn't great yet, but it doesn't completely break like it does without it. For testing I just gave the 9b-it model the Bee Movie script up to the "I like jazz" part (~7000 tokens), and it managed to generate an ending of Barry leaving through the window after some conversation (previously it just repeated random stuff like "You're a bee. I'm a bee.", lol). My change does fulfill the "each token can attend to at most W tokens from the previous layer" description from the Mistral paper. The mask isn't fully equal to the Mistral implementation, since that one does a log at the end, but I don't think that's related to SWA. It's also missing gguf parameters; it's just hardcoded right now (only enabled for gemma2 though).
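As a reading aid, here is a minimal, self-contained sketch of the masking rule described above (query `i` may attend to key `j` only if `0 <= i - j < W`). This is not the actual llama.cpp code; the `build_swa_mask` name and the dense row-major layout are just for illustration:

```cpp
#include <cmath>
#include <vector>

// Sketch of a sliding-window attention mask (illustrative, not llama.cpp internals).
// mask[i*n_kv + j] == 0.0f      -> query i may attend to key j
// mask[i*n_kv + j] == -INFINITY -> score is masked out before the softmax
static std::vector<float> build_swa_mask(int n_tokens, int n_kv, int window) {
    std::vector<float> mask(n_tokens * n_kv, -INFINITY);
    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_kv; ++j) {
            const bool causal = j <= i;         // no attending to future tokens
            const bool in_win = i - j < window; // at most `window` tokens back
            if (causal && in_win) {
                mask[i*n_kv + j] = 0.0f;
            }
        }
    }
    return mask; // added to KQ before the softmax, like the KQ_mask discussed below
}
```

A plain causal mask is the special case where `window >= n_kv`, so switching a layer between global and sliding-window attention only changes this one condition.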
It seems like the sliding window technique used by gemma 2 is mostly there to reduce KV cache memory usage. The idea is to use a sliding window for only some of the layers (a bit like Jamba, where some layers are swapped out for Mamba to save memory). So the idea would be to start by initializing a different KV size for each layer:

```cpp
// inside llama_kv_cache_init()
for (int i = 0; i < (int) n_layer; i++) {
    struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
    int kv_size_l = (i % 2 == 0)
        ? n_embd_k_gqa*kv_size // full KV size for even layers
        : n_embd_k_gqa*4096;   // "limited" size for the rest
    ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, kv_size_l);
    ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, kv_size_l);
    ...
}
```

And then there are patches that need to be implemented for …
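One possible way to make those smaller per-layer buffers usable is to treat each sliding-window layer's cache as a ring buffer indexed by absolute position. This is only a sketch of the idea, not how llama.cpp's cell allocator actually works, and `kv_cell_for_pos` is a hypothetical helper:

```cpp
// Sketch: map an absolute token position to a cell in a per-layer KV buffer.
// Global-attention layers keep kv_size cells; sliding-window layers only `window`.
static int kv_cell_for_pos(int pos, int kv_size, int window, bool is_swa_layer) {
    if (!is_swa_layer) {
        return pos % kv_size; // whole-context cache for global layers
    }
    return pos % window;      // ring buffer: position p overwrites p - window
}
```

With such a layout the K/V tensors of the sliding-window layers stay at `window` cells no matter how long the context grows, and the attention mask guarantees that the overwritten cells are never attended to.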
I did it there since I thought ggerganov suggested it with the linked #3228. And it does work. I'm pretty sure KQ_mask isn't just for batch masking, but also for general positional masking (the > pos check). At least it made the most sense to me to do it with the mask. But yeah, it's definitely not the most efficient way to implement it; the rest goes over my head.
@arlo-phoenix sorry, I overlooked that. KQ_mask has size …
@ngxson Yeah, sounds good! It would be too large a change for me to do cleanly anyway, so thank you for doing it instead! I think you saw the hacky commit arlo-phoenix@265a8f2, since that one requires the cleanup (for the every-other-layer SWA, then global, as you also commented above; just copy-pasted, the actual difference is minimal). I only propose that the default SWA size is kept at the gemma2 size, since that's what most people are interested in right now (and we already did the same for the other gemma2 things), so people don't need new quants.
Very neat code in chatllm. Really liked it!
How much more difficult would it be to add a similar change for the mistral 7b architecture, following the changes above for gemma2? Trying to compare with what was done here: implementation of sliding window attention in mlc-llm.
This issue was closed because it has been inactive for 14 days since being marked as stale.
Seems strange for this to auto-close, @ggerganov; maybe it can remain open for someone else to take on?
This issue was closed because it has been inactive for 14 days since being marked as stale.
For more info, see: https://github.com/mistralai/mistral-src and references therein.
Also: https://arxiv.org/pdf/2310.06825v1.pdf
With #3228 it should be relatively easy to support this.
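To make the memory argument concrete, here is a rough back-of-the-envelope comparison of per-layer KV cache size for a global-attention layer versus a sliding-window layer. The numbers are illustrative assumptions (window of 4096 as discussed above, f16 cache), not measurements:

```cpp
#include <cstdio>

// Rough per-layer KV cache size comparison (illustrative numbers only).
int main() {
    const long long n_ctx        = 32768; // context length
    const long long window       = 4096;  // sliding window size (W)
    const long long n_embd_k_gqa = 2048;  // K elements per token for this layer
    const long long bytes_per_el = 2;     // e.g. f16 cache

    const long long full_layer = 2 * n_ctx  * n_embd_k_gqa * bytes_per_el; // K + V
    const long long swa_layer  = 2 * window * n_embd_k_gqa * bytes_per_el; // K + V

    printf("global attention layer: %lld MiB\n", full_layer / (1024*1024)); // 256 MiB
    printf("sliding window layer:   %lld MiB\n", swa_layer  / (1024*1024)); //  32 MiB
    // A sliding-window layer only ever needs W cells, regardless of n_ctx.
    return 0;
}
```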