Cache: Static cache as a standalone object #30476
Conversation
Overall LFGTM
src/transformers/cache_utils.py
Outdated
 def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-    """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
+    """Returns the sequence length of the cached states that were seen by the model."""
     # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
     # limit the check to the first batch member and head dimension.
     # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
     # https://github.com/pytorch/pytorch/issues/120248 is fixed
-    return (self.key_cache[0, 0].any(dim=-1)).sum()
+    return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
will remove this one
it's slow and not reliable, generate should never use it
(needs a deprecation cycle and it's easier to do after we isolate the prefill stage, I'm going to leave it out of this PR)
fine by me to deprecate
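For illustration (not from the thread itself), here is a minimal, self-contained sketch of why the `any()`-based check flagged in the TODO above is error-prone: a position whose cached key happens to be all zeros is indistinguishable from an empty slot, so the reported length undercounts. Shapes and values below are hypothetical.

```python
import torch

# Hypothetical cache slice with shape (num_heads, max_seq_len, head_dim), mirroring
# what `self.key_cache[layer_idx][0]` would hold for the first batch member.
key_states = torch.zeros(2, 8, 4)

key_states[:, 0] = 1.0  # position 0 holds a non-zero key
key_states[:, 1] = 0.0  # position 1 was also written, but its key is exactly zero

# The check used by `get_seq_length`: count positions with any non-zero value.
seq_length = key_states[0].any(dim=-1).sum()
print(seq_length)  # tensor(1) -- undercounts, even though 2 positions were filled
```

This is why the comments above lean toward tracking the length separately (e.g. via `cache_position`) and eventually deprecating `get_seq_length` for static caches.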
raise ValueError(
    "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
    "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
Would be compatible if we slice the q k v efficiently, but that's too much trouble
Taking this on to finish!
If you use the memory efficient kernel it's 20% slower. That's what we use by default.
https://gist.github.com/ArthurZucker/ae0a86ef8f841c0ef69aaa52ccbc0b03 for the benchmarks
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
As I understand it, once the StaticCache is initialized, there is no need to pass it in the past_key_values argument. That's why an additional condition is necessary. Suggestion:
using_static_cache = isinstance(past_key_values, StaticCache) or isinstance(getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache)
@poedator This PR changes precisely the assumption you wrote: after this PR, we will always need to pass the cache; it is an object that does NOT live inside the model.
This change will make the transformers team's work easier 🤗
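To make the new contract concrete, here is a hedged sketch of the post-PR usage. The model name, cache sizes, and exact constructor arguments are illustrative and may differ between versions: the point is that the cache is built outside the model and must be passed explicitly on every call.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Illustrative model; any decoder-only model with static cache support works similarly.
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)

# The cache is a standalone object now: it no longer lives inside the model.
past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=256, device=model.device, dtype=torch.float16
)

# It has to be passed explicitly, together with `cache_position` (see the thread further down).
cache_position = torch.arange(inputs.input_ids.shape[1], device=model.device)
outputs = model(**inputs, past_key_values=past_key_values, cache_position=cache_position, use_cache=True)
```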
} | ||
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test | ||
# was changed to have a cache of 53 tokens (as opposed to 4096). | ||
EXPECTED_TEXT_COMPLETION = [ |
same comment as here: #30437 (comment). Please make sure to validate these tests on the T4 and A10 runners 🙏
There was indeed a mismatch on T4 🤗
Absolutely great work
src/transformers/cache_utils.py
Outdated
self.key_cache[layer_idx] *= 0.0
self.value_cache[layer_idx] *= 0.0
- self.key_cache[layer_idx] *= 0.0
- self.value_cache[layer_idx] *= 0.0
+ self.key_cache[layer_idx] = 0.0
+ self.value_cache[layer_idx] = 0.0
might be faster?
setting to a new tensor produces a graph break 💔 (I'm assuming you meant `self.key_cache[layer_idx] = torch.zeros(...)`)
No no, I think just filling them with zeros should work
That would result in `TypeError: 'float' object is not subscriptable` when indexing the cache :D
But filling with zeros with `tensor.zero_()` works 👍
ok 👍🏻 let's go with that then!
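Putting the conclusion of this thread into code, a sketch of what the agreed-upon reset could look like: in-place `zero_()` keeps the pre-allocated tensors (and their addresses) alive, so `torch.compile` sees no graph break.

```python
def reset(self):
    """Resets the cache values while preserving the pre-allocated tensors."""
    for layer_idx in range(len(self.key_cache)):
        # In-place zeroing: reassigning (`= 0.0` or `= torch.zeros(...)`) would either break
        # indexing or create new tensors and trigger a graph break under torch.compile.
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()
```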
if cache_position is None:
    if isinstance(past_key_values, StaticCache):
        raise ValueError("cache_position is a required argument when using StaticCache.")
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
Arf, alright, let's maybe add a TODO? As we won't be initializing with `get_seq_length` later on!
Added a todo on get_seq_length
👍
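Since `cache_position` is a required argument when a `StaticCache` is used, a brief sketch of how a caller might build it, reusing `model`, `past_key_values`, and `inputs` from the sketch earlier in the conversation (names and the greedy loop are illustrative): the tensor simply tracks which slots of the pre-allocated cache are being written.

```python
import torch

# Prefill: positions 0..prompt_len-1 are written into the pre-allocated cache.
input_ids = inputs.input_ids
cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
outputs = model(input_ids, past_key_values=past_key_values, cache_position=cache_position, use_cache=True)

# Decoding: each new token fills the next slot, so the position advances by one per step.
for _ in range(16):
    next_token = outputs.logits[:, -1:].argmax(dim=-1)
    cache_position = cache_position[-1:] + 1
    outputs = model(next_token, past_key_values=past_key_values, cache_position=cache_position, use_cache=True)
```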
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
this is new, and since we pass cache_position, let's use cache_position[0]
Agreed in theory, can't do in practice: breaks torch.fx tests 💔
yeah thought so
if using_static_cache:
    target_length = past_key_values.get_max_length()
can't we always use get_max_length()?
`get_max_length()` is `None` in the dynamic caches
It should be seq_length
but alright
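To spell out the point being made in this exchange, a hedged sketch of the branch under discussion (the dynamic-cache fallback shown here is a simplification of the actual mask-sizing code):

```python
if using_static_cache:
    # Static cache: the mask must cover the full pre-allocated length.
    target_length = past_key_values.get_max_length()
else:
    # Dynamic caches return None from get_max_length(), so fall back to what has
    # actually been seen so far plus the current chunk.
    target_length = past_seen_tokens + sequence_length
```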
@@ -684,15 +683,25 @@ def test_model_13b_greedy_generation(self):
    @require_torch_gpu
    @require_read_token
    def test_compile_static_cache(self):
should require torch > 2.2
# Static Cache + compile
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
    **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
good thanks
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
and 2.2.1 works as well
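A small sketch of the kind of version gate implied by this exchange, written with plain `packaging`/`unittest` rather than the repo's actual test decorator (which this code does not claim to be): only torch==2.2.0 is affected, while 2.1.2, 2.2.1 and newer work.

```python
import unittest

import torch
from packaging import version

# Illustrative guard (not the helper used in the transformers test suite): skip the
# compile + static cache test on the one affected release, torch==2.2.0.
# See https://github.com/pytorch/pytorch/issues/121943.
torch_release = version.parse(torch.__version__.split("+")[0])
if torch_release == version.parse("2.2.0"):
    raise unittest.SkipTest("torch==2.2.0 cannot run this compile + static cache test")
```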
What does this PR do?
Replaces the current format of StaticCache [an object living inside a model, containing the cache for one layer] with a standalone object matching the other Cache objects. The new format preserves the existing torch.compile capabilities while being easier to manipulate, especially outside a model. In the process, it removes all traces of the previous format across all models, tests, and docs.
Fixes #30417 (in place of #30437)
Fixes #30351
Benchmarks
(RTX3090, tiny-llama model, torch==2.4.0.dev20240424+cu121)
Benchmark code
commit == 14b19c4ef365f90797e07b2a20caaaaf3901b2d2
[Screenshot 2024-04-25 at 10 05 23: benchmark results]
v4.39.0
[Screenshot 2024-04-25 at 10 05 48: benchmark results]