
Some improvements for KV caching #1891

Merged: 3 commits merged into Lightning-AI:main from the kvcache_improvements4 branch on Dec 31, 2024

Conversation

@mseeger (Contributor) commented Dec 26, 2024

  • Ensure that KVCache buffers are only as large as config.n_query_groups
  • Shrink buffers returned by KVCache to just cover input_pos entries (a sketch of these first two points follows this list)
  • Clean up the child classes of the classes in model.py; in particular, remove copied forward methods
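
As a rough illustration of the first two points, here is a minimal sketch (hypothetical class and method names, not the actual litgpt implementation) of a KV cache whose buffers are sized by config.n_query_groups and which returns only the slice covering the positions written so far, using an explicit int bound input_pos_maxp1 as discussed later in this thread:

```python
import torch

class KVCacheSketch:
    """Illustrative only: buffers sized by n_query_groups (GQA/MQA), not n_head;
    the returned views cover just the filled prefix of the sequence dimension."""

    def __init__(self, batch_size: int, n_query_groups: int, max_seq_length: int,
                 head_size: int, dtype: torch.dtype = torch.float32):
        shape = (batch_size, n_query_groups, max_seq_length, head_size)
        self.k = torch.zeros(shape, dtype=dtype)
        self.v = torch.zeros(shape, dtype=dtype)

    def __call__(self, input_pos: torch.Tensor, input_pos_maxp1: int,
                 k: torch.Tensor, v: torch.Tensor):
        # Write the new entries at the given positions along the sequence dim ...
        self.k.index_copy_(2, input_pos, k)
        self.v.index_copy_(2, input_pos, v)
        # ... and return only the prefix [0, input_pos_maxp1); the bound is a
        # plain Python int, so no tensor value turns into a tensor size.
        return self.k[:, :, :input_pos_maxp1, :], self.v[:, :, :input_pos_maxp1, :]
```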

@mseeger force-pushed the kvcache_improvements4 branch from 69d6d6f to a65a96d on December 27, 2024, 13:40
@mseeger (Contributor, Author) commented Dec 27, 2024

Can somebody help with the failing tests? I don't understand why the tests fail on Windows but pass on all other systems. I also don't understand why the GPU tests are failing.

@Andrei-Aksionov (Contributor) commented:

Hello @mseeger

Thank you for another PR.

> Can somebody help with the failing tests? I don't understand why the tests fail on Windows

Yeah, there is always something with Windows.
Perhaps we'll be lucky and the simple torch bump from #1893 will help 🤷

> I also don't understand why the GPU tests are failing.

I'll check it tomorrow.

[Code review threads on litgpt/adapter.py, litgpt/adapter_v2.py, and litgpt/model.py (resolved).]
@Andrei-Aksionov (Contributor) commented:

Hello @mseeger

It's quite a PR 🫠 🙂.
I left a couple of comments.
Overall it looks really good. I like how the PEFT variants look now; it makes it easy to focus on the differences 😊.

(I'll take a look why GPU tests are failing later.)

@mseeger force-pushed the kvcache_improvements4 branch from 7f2c2ce to 3226323 on December 31, 2024, 10:21
@mseeger (Contributor, Author) commented Dec 31, 2024

OK, I responded to the comments. I also made a small change in lora.py, where CausalSelfAttention.__init__ was still copied and pasted; now it calls the superclass __init__.
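
A sketch of the kind of change described (simplified, with hypothetical attribute names, not the exact lora.py code): the LoRA attention class calls the base-class constructor instead of duplicating its body, and only adds its own state afterwards.

```python
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    # Simplified stand-in for the base class in model.py.
    def __init__(self, config, block_idx: int) -> None:
        super().__init__()
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.block_idx = block_idx

class LoRACausalSelfAttention(CausalSelfAttention):
    def __init__(self, config, block_idx: int) -> None:
        # Reuse the base-class construction instead of copying its body ...
        super().__init__(config, block_idx)
        # ... and add only the LoRA-specific state on top (hypothetical field).
        self.lora_r = getattr(config, "lora_r", 0)
```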

[Code review thread on litgpt/model.py (resolved).]
@Andrei-Aksionov (Contributor) commented:

> OK, I responded to the comments. I also made a small change in lora.py, where CausalSelfAttention.__init__ was still copied and pasted; now it calls the superclass __init__.

Cool, we are almost there 🙂.
There are a couple of unresolved comments left.

On my side, I'll try to find and fix the issues with the failing GPU tests, hopefully this year 😃.

- Shrink buffers returned by KVCache to just cover input_pos entries
- Refactor child classes of model.py classes to avoid copy and paste
@mseeger force-pushed the kvcache_improvements4 branch from 3226323 to 3702b03 on December 31, 2024, 12:50
@Andrei-Aksionov (Contributor) commented:

Overall, the issue with GPU+Thunder is specific to the latter.
I'll merge the PR as is and discuss it with the Thunder team later.

Thanks again for the PR (and for the patience 😊).

Happy New Year! 🚀

@Andrei-Aksionov merged commit 17a58df into Lightning-AI:main on Dec 31, 2024.
8 of 9 checks passed
@t-vi (Contributor) commented Jan 20, 2025

@mseeger unfortunately, the "Shrink buffers returned by KVCache to just cover input_pos entries" part does not work as intended: introducing data-dependent control flow (here, making a tensor size out of a tensor value) typically breaks compilation just as much as using CPU integers does.

The other aspect is that, in my experience, the performance impact of this is extremely limited as long as the attention implementation works reasonably, so I would tend to revert this part of the change.

@mseeger (Contributor, Author) commented Jan 20, 2025

Hello, my original PR did not make a tensor size out of a tensor value. In fact, that is why input_pos_maxp1 must be an extra argument (of int type) rather than being computed as the max of input_pos.

But sure, feel free to revert this part (the other changes are useful, I think). I am already working on something that would make it obsolete.
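
To make the distinction concrete, here is a minimal illustration (hypothetical usage, not the PR's code) of the two ways of obtaining the bound:

```python
import torch

input_pos = torch.tensor([5, 6, 7])

# Deriving the bound from the tensor's value: slicing with it would make a
# tensor size depend on a tensor value, which tends to break compilation.
maxp1_from_tensor = int(input_pos.max()) + 1  # forces a CPU integer / sync

# What the PR does instead: the generation loop already tracks the current
# position, so it passes the bound as a plain Python int alongside input_pos.
input_pos_maxp1 = 8  # known to the caller without reading any tensor value
```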

@mseeger (Contributor, Author) commented Jan 20, 2025

My intent is for LitGPT to support modern KV caching and sparse attention. That is something that, for example, Hugging Face does not do properly: they support some KV caching, but only for a few models.

@mseeger (Contributor, Author) commented Jan 20, 2025

With what I am preparing, KV caches would get a proper abstract definition (with a very simple implementation of what is currently there, namely the exact KV cache), and input_pos_maxp1 would be removed.

@t-vi (Contributor) commented Jan 20, 2025

```python
if input_pos_maxp1 > self.max_seq_length:
    raise ValueError(f"Positions in 'input_pos' must be in [0,{self.max_seq_length})")
mask = mask[..., :input_pos_maxp1]
```

So in these lines, if input_pos_maxp1 is a tensor, the if is data-dependent control flow because it depends on the value of the tensor. I'm not sure the check is per se useful.
The second line, using the tensor for slicing, makes mask of size min(old_size, input_pos_maxp1), which depends on the value of the tensor.

I know these things are subtle. In #1912 we specifically drop input_pos_maxp1 from generate when dealing with ThunderModules.
What we found useful (and what I would prefer over the current implementation) is, instead of slicing the tensors, to pass input_pos_maxp1 and the full tensors to the attention function and let it decide not to look at the later parts of the tensors. This may also be advantageous in the case of "funny sizes" (not divisible by whatever power of two the attention kernel wants for loading things into memory), because it seems not-so-great style to access tensors out of bounds even if you know they are slices of larger ones. We could have a default attention function that does the slicing in eager mode, and compilers could then override it as needed.
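
A minimal sketch of that idea (hypothetical function name and signature, not the #1912 code): the caller hands the attention function the full K/V buffers plus the valid length, and the default eager implementation slices internally, so a compiler backend could replace it with a kernel that respects the bound without materializing the slices.

```python
import torch
import torch.nn.functional as F

def attend_up_to(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                 input_pos_maxp1: int) -> torch.Tensor:
    # Default eager behaviour: ignore cache entries past the valid length by
    # slicing inside the attention function rather than at the call site.
    k = k[..., :input_pos_maxp1, :]
    v = v[..., :input_pos_maxp1, :]
    return F.scaled_dot_product_attention(q, k, v)
```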

@t-vi (Contributor) commented Jan 20, 2025

(And needless to add, I do greatly appreciate your work on KVCaches in LitGPT.)
