
llama : support sliding window attention #3377

Closed
ggerganov opened this issue Sep 28, 2023 · 21 comments
Labels
performance (Speed related topics), stale

Comments

@ggerganov
Owner

ggerganov commented Sep 28, 2023

For more info, see: https://github.com/mistralai/mistral-src and references therein.

Also: https://arxiv.org/pdf/2310.06825v1.pdf

With #3228 it should be relatively easy to support this.

@CoruNethron

Trying to figure this out. I found this reference helpful, as it seems to put all the necessary code in one place:
https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py

Looks like the Attention.forward() Python function has something in common with the self-attention part of build_llama() on the C++ side.

Also, since KQ_mask already exists, it looks like not many changes are needed.
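
For reference, the condition such a mask needs to encode is small. A minimal sketch with illustrative names only (W is the window size, e.g. 4096 for Mistral 7B; none of these are actual llama.cpp or mistral-src identifiers):

#include <cstdio>

// Sliding window attention, illustrative only: a query at position pos_q may
// attend to a key at position pos_k iff the key is not in the future (causal)
// and lies within the last W positions.
static bool swa_can_attend(int pos_q, int pos_k, int W) {
    return pos_k <= pos_q && pos_q - pos_k < W;
}

int main() {
    // With W = 4, token 10 can attend to tokens 7..10 but not to token 6.
    printf("%d %d\n", swa_can_attend(10, 7, 4), swa_can_attend(10, 6, 4)); // prints: 1 0
    return 0;
}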

I can probably manage it in a week or two, but I'm not sure.

@stygmate

Any news on this one? 🥹
The best low-memory models are currently based on Mistral, and being locked to a 4096-token window is very limiting for tasks like document analysis.

@h3ndrik

h3ndrik commented Nov 29, 2023

I think KoboldCPP introduced a similar feature with their recent rewrite of the "smartcontext" feature.

@HiroseKoichi

I think KoboldCPP introduced a similar feature with their recent rewrite of the "smartcontext" feature.

Nope, they added cache shifting. Sliding window attention is a different attention mechanism; the two do very different things.

@stygmate

stygmate commented Dec 26, 2023

Maybe it's not appropriate to insist (I apologize if this bothers you), but this feature seems to be among those with the most thumbs-up on the dashboard. 😢

@github-actions github-actions bot added the stale label Mar 20, 2024
Contributor

github-actions bot commented Apr 3, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

@github-actions github-actions bot closed this as completed Apr 3, 2024
@ggerganov
Owner Author

It feels like there hasn't been much interest in this technique since Mistral 7B last year; even later Mistral models dropped it as a feature. Taking this into account, I guess we can leave this issue closed

@ggerganov ggerganov closed this as not planned (won't fix, can't repro, duplicate, stale) Apr 4, 2024
@ggerganov ggerganov moved this from Todo to Done in ggml : roadmap Apr 4, 2024
@candre23

I guess we can leave this issue closed

@ggerganov As the new gemma 2 models use SWA (in addition to GQA, in some sort of alternating scheme?), I suggest this be revisited. As-is, gemma 2 pretty much falls apart after 4k context or so using llama.cpp.

@CoruNethron

I can probably manage it in a week or two, but I'm not sure.

Just for the record: I looked at this again in January or February, but didn't manage it. My understanding isn't deep enough, so I don't expect to finish it soon.

Still, it really seems that only a few additional lines of code are needed.

@Galunid
Collaborator

Galunid commented Jun 29, 2024

I think it may be worth re-evaluating this, since there's an increasing number of models supporting SWA (mistral, phi, gemma2, off the top of my head).

@Galunid Galunid reopened this Jun 29, 2024
@Galunid Galunid removed the stale label Jun 29, 2024
@arlo-phoenix
Contributor

arlo-phoenix commented Jun 29, 2024

I roughly implemented sliding window attention here: https://github.com/arlo-phoenix/llama.cpp/tree/gemma2
The branch is already rebased on #8197, so this should fix all gemma2 bugs.

No idea if it's correct; the output isn't great yet, but it doesn't completely break like it does without the change. For testing I just gave the 9b-it model the Bee Movie script up to the "I like jazz" part (~7000 tokens), and it managed to generate an ending with Barry leaving through the window after some conversation (previously it just repeated random stuff like "You're a bee. I'm a bee." lol)

My change does fulfill the "each token can attend to at most W tokens from the previous layer" description from the Mistral paper. The mask isn't fully equal to the Mistral implementation, since that one applies a log at the end, but I don't think that's related to SWA. It's also missing GGUF parameters; the window size is just hardcoded right now (and only enabled for gemma2).
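
Roughly, the extra masking amounts to something like this (a hedged sketch with made-up names, not the exact code in the branch; 0.0f keeps a key visible, -INFINITY removes it from the softmax):

#include <cmath>
#include <vector>

// Sketch only: build one row of a KQ-style additive mask for a query at
// position pos_q over cached keys at positions pos_k[], with window n_swa.
static std::vector<float> build_swa_mask_row(int pos_q, const std::vector<int> & pos_k, int n_swa) {
    std::vector<float> row(pos_k.size(), 0.0f);
    for (size_t j = 0; j < pos_k.size(); ++j) {
        const bool is_future  = pos_k[j] > pos_q;          // usual causal check
        const bool out_of_win = pos_q - pos_k[j] >= n_swa; // extra sliding-window check
        if (is_future || out_of_win) {
            row[j] = -INFINITY;
        }
    }
    return row;
}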

@ngxson
Collaborator

ngxson commented Jun 29, 2024

It seems like the sliding window technique used by gemma 2 is mostly there to reduce KV cache memory usage. The idea is to use a sliding window for only some of the layers (a bit like Jamba, where some layers are swapped out for Mamba to save memory).

So the idea would be to start by initializing different KV size for each layer:

// inside llama_kv_cache_init()
for (int i = 0; i < (int) n_layer; i++) {
    struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
    int kv_size_l = (i % 2 == 0)
        ? n_embd_k_gqa*kv_size // full KV size for even layers
        : n_embd_k_gqa*4096;   // "limited" size for the rest
    ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, kv_size_l);
    ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, kv_size_l);
    ...
}

And then there are patches that need to be implemented for llm_build_kv_store, llm_build_kqv, etc. I haven't looked deeply into that.
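
(For intuition on why llm_build_kv_store would need touching: with a per-layer cache of only n_swa cells, the store step has to wrap around and overwrite entries older than the window, i.e. the "rolling buffer" from the Mistral paper. A hedged sketch with a hypothetical helper name, not actual llama.cpp code:)

// Hypothetical helper, sketch only: map a token position to a cell in a rolling
// buffer of n_swa cells, so a new token overwrites the entry that just left the window.
static int swa_cache_cell(int pos, int n_swa) {
    return pos % n_swa; // e.g. with n_swa = 4096, position 5000 lands in cell 904
}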

I had a look at @arlo-phoenix's fork, but I'm not sure that's the right direction. KQ_mask is used for masking different sequences inside a batch, so it's probably unrelated here. (see comment below)

@arlo-phoenix
Contributor

I did it there since I thought ggerganov suggested it with the linked #3228, and it does work. I'm pretty sure KQ_mask isn't just for batch masking but also for general positional masking (the > pos check), so it made the most sense to me to do it with the mask. But yeah, it's definitely not the most efficient way to implement it; the rest goes over my head.

@ngxson
Collaborator

ngxson commented Jun 30, 2024

@arlo-phoenix sorry, I overlooked that. KQ_mask has size [kv_size, n_batch], so clearly it also masks tokens in the KV cache, not just the batch. If you don't mind, I can propose a cleaner PR based on yours later today. Even if it's not the most efficient way, I believe it can be a good start.

@arlo-phoenix
Contributor

@arlo-phoenix sorry, I overlooked that. KQ_mask has size [kv_size, n_batch], so clearly it also masks tokens in the KV cache, not just the batch. If you don't mind, I can propose a cleaner PR based on yours later today. Even if it's not the most efficient way, I believe it can be a good start.

@ngxson Yeah, sounds good! It would be too large a change for me to do cleanly anyway, so thank you for doing it instead! I think you saw the hacky commit arlo-phoenix@265a8f2, since that's the one that needs the cleanup (for the every-other-layer SWA-then-global scheme, as you also commented above; it's just copy-pasted, the actual difference is minimal). I'd only suggest keeping the default SWA size at the gemma2 value, since that's what most people are interested in right now (and we already did the same for the other gemma2 settings), so people don't need new quants.

@matteoserva
Contributor

matteoserva commented Jun 30, 2024

Even gemma.cpp, the reference implementation by Google, is giving me subpar results.

The best implementation is by @foldl in his chatllm project.
It gives the exact same results as the AI Studio version of gemma 27b.

@qnixsynapse
Contributor

Very neat code in chatllm. Really liked it!

@Faolain

Faolain commented Jul 30, 2024

How much more difficult would it be to make a similar change for the mistral 7b architecture, following the changes above for gemma2? I'm trying to compare with what was done here:

Implementation of sliding window attention in mlc-llm

@github-actions github-actions bot added the stale label Aug 31, 2024
Contributor

This issue was closed because it has been inactive for 14 days since being marked as stale.

@Faolain

Faolain commented Sep 16, 2024

Seems strange for this to auto-close, @ggerganov; maybe it can remain open for someone else to take on?

@Galunid Galunid reopened this Sep 16, 2024
@github-actions github-actions bot removed the stale label Sep 17, 2024
@github-actions github-actions bot added the stale label Oct 17, 2024
Contributor

github-actions bot commented Nov 1, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

@github-actions github-actions bot closed this as completed Nov 1, 2024