Init cache on meta device #35164
Conversation
Regarding the offloaded cache: is there a plausible use case where the current changes would be a fix? Given that it is meant to be used in a low-resource setting (i.e. no multiple GPUs), I would leave it in its original implementation if multi-GPU support is the only purpose of the changes. In offloading we want to be precise about devices, which is the opposite of meta.
Other than that LGTM :)
This PR has many moving pieces associated with it, so I'm leaving a few questions below to ensure we dot all i's and cross all t's:
- The original issue is long. Let's document this PR with a minimal example to reproduce the issue. It will be useful in case we need to understand why we did this;
- I'm assuming you ran tests locally, but let's write down which commands were run. (minimum: slow llama tests, slow cache tests)
- Have you benchmarked llama + static cache + compilation? If yes, leave a note in the PR header. If not, please double-check :)
- [if we want to keep the changes for the offloaded cache] Ditto for the offloaded cache, make sure it is benchmarked before and after these changes. I'm assuming existing tests would catch any correctness issue.
@@ -462,26 +462,6 @@ def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache
    with self.subTest(f"{attn_implementation}, static, eager"):
        self.assertListEqual(decoded, EXPECTED_GENERATION)

    set_seed(0)
I suspect this is related to what you wrote about offloaded caches + cuda graphs.
We test compilation in other places, so I agree it is fine to delete :)
Yep, exactly! Since we compile whenever a `StaticCache` is used, the model gets compiled for `cache_implementation="offloaded_cache"`, and these tests were not being run, so they didn't catch the graph breaks. If we want to keep `cache_implementation="offloaded_cache"` working, we either disable compile specifically for this cache type or make this cache not an instance of `StaticCache`. I am not sure if we are planning to keep the "auto-compile the forward if static cache" feature.
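To make that concrete, here is a rough sketch of the kind of gate being discussed; this is a hypothetical simplification for illustration, not the actual transformers code:

```python
import torch
from transformers import StaticCache

def maybe_compile_forward(model, past_key_values):
    # Hypothetical simplification of the "auto-compile the forward if static cache"
    # behaviour mentioned above. Because an offloaded static cache is an instance of
    # StaticCache, it passes this check and the forward gets compiled too, even though
    # its prefetch stream introduces graph breaks.
    if isinstance(past_key_values, StaticCache):
        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    return model
```

Making the offloaded cache not a subclass of `StaticCache`, or special-casing it, would be enough to skip such a path.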
2a1de72 to aed62b1 (compare)
It's indeed a better way to do this! Let's iterate a bit tho!
src/transformers/cache_utils.py (outdated)
if k_out.device.type == "meta":
    k_out = torch.zeros(*k_out.size(), device=key_states.device, dtype=key_states.dtype)
    v_out = torch.zeros(*v_out.size(), device=value_states.device, dtype=value_states.dtype)
    self.key_cache[layer_idx] = k_out
    self.value_cache[layer_idx] = v_out
IMO the else branch should do:

```
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
```

Not sure, memory-wise, whether k_out is erased first or whether more memory is allocated at this point.
Yes, it shows more memory allocation for me. I will rewrite it a bit to make the code more easily inspectable.
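For illustration, a minimal sketch of the kind of restructuring being discussed here, assuming the cache keeps per-layer `key_cache`/`value_cache` lists (this is not the final PR code):

```python
def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
    # Sketch only: materialize this layer's cache on the device/dtype of the incoming
    # states the first time it is used; otherwise reuse the existing tensors so that no
    # extra memory is allocated.
    if self.key_cache[layer_idx].device.type == "meta":
        self.key_cache[layer_idx] = torch.zeros_like(
            self.key_cache[layer_idx], device=key_states.device, dtype=key_states.dtype
        )
        self.value_cache[layer_idx] = torch.zeros_like(
            self.value_cache[layer_idx], device=value_states.device, dtype=value_states.dtype
        )
    k_out = self.key_cache[layer_idx]
    v_out = self.value_cache[layer_idx]
    # ... write key_states / value_states into k_out / v_out as before ...
    return k_out, v_out
```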
"Hello I am doing a project for my school and I am trying to make a program that will allow me to input a", | ||
"Hello I am doing a project for my school and I am trying to make a program that will allow me to use a", |
The test was never passing because it had no `num_beams`; the hub config hasn't changed since release. So I just fixed it.
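As a hedged illustration of the kind of fix described (the exact values and surrounding test code are assumptions, not necessarily the PR's actual change):

```python
# Illustrative only: pass num_beams explicitly in the test's generate call instead of
# relying on the hub generation config, so the expected beam-search outputs can be hit.
gen_out = model.generate(**inputs, do_sample=False, num_beams=2, max_new_tokens=20)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
```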
LGTM ! Thanks for fixing this !
Before merging, let's make sure we use the correct torch primitives and keep the code as simple as possible!
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
A lot better; just be careful with one change and it should be good!
self.conv_states: torch.Tensor = torch.zeros(
    config.num_hidden_layers,
    self.max_batch_size,
    self.intermediate_size,
    self.conv_kernel_size,
    device=device,
    dtype=dtype,
)
self.ssm_states: torch.Tensor = torch.zeros(
    config.num_hidden_layers,
    self.max_batch_size,
    self.intermediate_size,
    self.ssm_state_size,
    device=device,
    dtype=dtype,
)
why do we change the shape here?
So we can move tensors from "meta" to "cuda" layer by layer, whenever the cache is updated for a given layer. Otherwise, we could move the whole 5D cache at once when the first layer is updated, but I just wanted to be consistent with the other static cache classes.
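A rough sketch of the per-layer idea described above, with simplified, hypothetical names (not the exact PR code): each layer's state starts on the meta device and is materialized on the device of the incoming states the first time that layer is updated.

```python
import torch

class PerLayerMambaCacheSketch:
    """Hypothetical sketch of a per-layer mamba cache that starts on the meta device."""

    def __init__(self, num_layers, batch_size, intermediate_size, conv_kernel_size, dtype=torch.float32):
        # One tensor per layer, so each layer can be moved off "meta" independently.
        self.conv_states = [
            torch.zeros(batch_size, intermediate_size, conv_kernel_size, device="meta", dtype=dtype)
            for _ in range(num_layers)
        ]

    def update_conv_state(self, layer_idx, new_state):
        if self.conv_states[layer_idx].device.type == "meta":
            # Materialize only this layer on the device of the incoming states,
            # instead of moving one big 5D tensor all at once.
            self.conv_states[layer_idx] = torch.zeros_like(
                self.conv_states[layer_idx], device=new_state.device
            )
        self.conv_states[layer_idx].copy_(new_state)
        return self.conv_states[layer_idx]
```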
OK, sounds good; in that case devices can indeed be different!
* init cache on meta device
* offloaded static + enable tests
* tests weren't running before :(
* update
* fix mamba
* fix copies
* update
* address comments and fix tests
* fix copies
* Update src/transformers/cache_utils.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* update
* mamba fix

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
What does this PR do?
Fixes #33147.
Initializes the cache on the meta device until we get the first key/values and infer the device from them. Removed `layer_device_map` as it is not needed anymore in most cases, because initializing the cache on the meta device is the default when no `device` is given. Offloaded Static cache would still require a layer device map or a single device, since it prefetches key/values in advance and we cannot infer the device as soon as we see the input.

One thing to note is that Offloaded Static cache can never run on current `main`, because `torch.cuda.stream()` is not fullgraph-compile compatible. Found a related issue from the torch team on that: pytorch/pytorch#92804. On current `main` I can't run it even with graph breaks, and it fails after the 3rd layer on `torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream)`, but with this branch it runs if we allow graph breaks, until the compile cache limit is reached. Since this is the only cache that is behaving differently, maybe we can make it not an instance of `StaticCache`, at least until fullgraph compile is working.

Tests in `test_cache_utils.py` and on Llama are green when compared to the `main` branch. Some slow Llama tests are red on `main` and do not use the static cache, so it wasn't caused by this PR.

Issue repro with 2 GPUs
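The actual repro lives in the collapsed section above; as a hedged sketch, assuming the setup from #33147 (a checkpoint sharded across two GPUs with `device_map="auto"` plus a static cache), it could look roughly like this:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative 2-GPU repro sketch (assumed setup, not the exact script from this PR):
# shard the model across the available GPUs and generate with a static cache,
# the multi-GPU scenario this PR addresses (issue #33147).
model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

inputs = tokenizer("Hello I am doing", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```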
Benchmark on llama + compile with `meta-llama/Llama-3.2-1B-Instruct`: