
Init cache on meta device #35164

Merged 14 commits into main on Jan 22, 2025

Conversation

@zucchini-nlp (Member) commented on Dec 9, 2024

What does this PR do?

Fixes #33147.

Initializes the cache on the meta device until we get the first key/values and can infer the device from them. Removed layer_device_map, as it is no longer needed in most cases, because initializing the cache on the meta device is now the default when no device is given. The offloaded static cache still requires a layer device map or a single device, since it prefetches key/values in advance, so we cannot wait to infer the device from the input.
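
For reference, a minimal sketch of the lazy-init pattern described above (shapes, names, and the update helper are made up for illustration; this is not the actual StaticCache implementation):

```python
import torch

# Allocate the per-layer cache on the meta device (no memory, no device decision yet),
# then materialize each layer on the device of the first key states it receives.
num_layers, batch, heads, max_len, head_dim = 2, 1, 4, 16, 8
key_cache = [
    torch.zeros(batch, heads, max_len, head_dim, device="meta") for _ in range(num_layers)
]

def update(layer_idx: int, key_states: torch.Tensor) -> torch.Tensor:
    k_out = key_cache[layer_idx]
    if k_out.device.type == "meta":
        # First update for this layer: infer device/dtype from the incoming states.
        k_out = torch.zeros(*k_out.size(), device=key_states.device, dtype=key_states.dtype)
        key_cache[layer_idx] = k_out
    k_out[:, :, : key_states.shape[2]] = key_states
    return k_out

# Example: layer 0 receives CPU key states, so its cache is materialized on CPU.
update(0, torch.randn(batch, heads, 3, head_dim))
```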

One thing to note is that the offloaded static cache can never run on current main, because torch.cuda.stream() is not fullgraph-compile compatible. Found a related issue from the torch team on that: pytorch/pytorch#92804. On current main I can't run it even with graph breaks allowed; it fails after the 3rd layer on torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream). With this branch it runs if we allow graph breaks, until the compile cache limit is reached.
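
For context, a rough sketch of the stream-based prefetch pattern being discussed (purely illustrative; the names, shapes, and helper functions are made up and do not match the actual OffloadedStaticCache code):

```python
import torch

# Side-stream prefetch: copy the next layer's CPU cache to the GPU while the current
# layer computes. The torch.cuda.stream()/wait_stream() pair below is the part that
# fullgraph compilation currently cannot trace.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    prefetch_stream = torch.cuda.Stream(device=device)
    cpu_key_cache = [torch.zeros(1, 4, 16, 8).pin_memory() for _ in range(4)]
    device_key_cache = torch.zeros(1, 4, 16, 8, device=device)

    def prefetch_layer(layer_idx: int) -> None:
        # Issue the host-to-device copy asynchronously on the side stream.
        with torch.cuda.stream(prefetch_stream):
            device_key_cache.copy_(cpu_key_cache[layer_idx], non_blocking=True)

    def wait_for_prefetch() -> None:
        # Make the default stream wait until the prefetch copy has finished.
        torch.cuda.default_stream(device).wait_stream(prefetch_stream)

    prefetch_layer(0)
    wait_for_prefetch()
```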

Since this is the only cache that behaves differently, maybe we can make it not an instance of StaticCache, at least until fullgraph compile works for it.


  • Slow tests in test_cache_utils.py and on Llama are green when compared to the main branch. Some slow Llama tests are red on main and do not use the static cache, so those failures weren't caused by this PR.

Issue repro with 2 GPUs

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = 'google/gemma-2-2b-it'
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Split the model across two GPUs: layers 0-12 on cuda:0, layers 13-25 on cuda:1
device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
num_hidden_layers = 26
for i in range(num_hidden_layers):
    device_map[f"model.layers.{i}"] = 0 if i < 13 else 1

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="bfloat16",
    device_map=device_map,
)

inputs = tokenizer("Today is a beautiful day!", return_tensors='pt').to(0)
# On main, this call fails with "indices should be either on cpu or on the same device as the indexed tensor (cuda:1)"
out = model(**inputs)

Benchmark on llama + compile with meta-llama/Llama-3.2-1B-Instruct:

[benchmark figure]

@zucchini-nlp requested a review from gante on December 9, 2024 11:06
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp changed the title from "[WIP] Init cache on meta device" to "Init cache on meta device" on Dec 10, 2024

@gante (Member) left a comment:

Regarding the offloaded cache: is there a plausible use case where the current changes would be a fix? Given that it is meant to be used in a low-resource setting (= no multiple GPUs), I would leave it in its original implementation if multi-GPU support is the only purpose of the changes. In offloading we want to be precise with devices, which is the opposite of meta.

Other than that LGTM :)

This PR has many moving pieces associated with it, so I'm leaving a few questions below to ensure we dot all the i's and cross all the t's:

  1. The original issue is long. Let's document this PR with a minimal example to reproduce the issue. It will be useful in case we need to understand why we did this;
  2. I'm assuming you ran tests locally, but let's write down which commands were run (minimum: slow llama tests, slow cache tests);
  3. Have you benchmarked llama + static cache + compilation? If yes, leave a note in the PR header. If not, please double-check :)
  4. [If we want to keep the changes for the offloaded cache] Ditto for the offloaded cache: make sure it is benchmarked before and after these changes. I'm assuming existing tests would catch any correctness issue.

@@ -462,26 +462,6 @@ def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache
with self.subTest(f"{attn_implementation}, static, eager"):
    self.assertListEqual(decoded, EXPECTED_GENERATION)

set_seed(0)

Member:

I suspect this is related to what you wrote about offloaded caches + cuda graphs.

We test compilation in other places, so I agree it is fine to delete :)

@zucchini-nlp (Member, Author) commented on Jan 15, 2025:

Yep, exactly! Since we compile whenever a StaticCache is used, the model gets compiled for cache_implementation="offloaded_cache".

And these tests were not being run, so they did not catch the graph breaks. If we want to keep cache_implementation="offloaded_cache" working, we either disable compile specifically for this cache type or make this cache not an instance of StaticCache. I am not sure if we are planning to keep the "auto-compile the forward if static cache" feature.
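
A hypothetical sketch of the first option (the gating function is made up; only the two cache classes are real):

```python
from transformers.cache_utils import OffloadedStaticCache, StaticCache

def should_auto_compile(cache) -> bool:
    # Hypothetical gate: auto-compile the forward only for the plain static cache,
    # skipping the offloaded variant whose CUDA-stream prefetch breaks fullgraph compile.
    return isinstance(cache, StaticCache) and not isinstance(cache, OffloadedStaticCache)
```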

@ArthurZucker (Collaborator) left a comment:

It's indeed a better way to do this! Let's iterate a bit tho!

Comment on lines 1205 to 1209
if k_out.device.type == "meta":
    k_out = torch.zeros(*k_out.size(), device=key_states.device, dtype=key_states.dtype)
    v_out = torch.zeros(*v_out.size(), device=value_states.device, dtype=value_states.dtype)
    self.key_cache[layer_idx] = k_out
    self.value_cache[layer_idx] = v_out

Collaborator:

IMO the else branch should do:

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
Not sure, memory-wise, whether k_out is freed first or whether more memory gets allocated at this point.

@zucchini-nlp (Member, Author):

Yes, it shows more memory allocation for me. I will rewrite it a bit to make the code easier to inspect.

Comment on lines +318 to +319
"Hello I am doing a project for my school and I am trying to make a program that will allow me to input a",
"Hello I am doing a project for my school and I am trying to make a program that will allow me to use a",

@zucchini-nlp (Member, Author):

The test was never passing because it had no num_beams (the hub config hasn't changed since release), so I just fixed it.

@SunMarc (Member) left a comment:

LGTM! Thanks for fixing this!

@ArthurZucker (Collaborator) left a comment:

Before merging, let's make sure we use the correct torch primitives and the simplest code possible!

zucchini-nlp and others added 2 commits January 21, 2025 11:49
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

@ArthurZucker (Collaborator) left a comment:

A lot better; just be careful with one change and it should be good!

Comment on lines -1813 to -1829
self.conv_states: torch.Tensor = torch.zeros(
    config.num_hidden_layers,
    self.max_batch_size,
    self.intermediate_size,
    self.conv_kernel_size,
    device=device,
    dtype=dtype,
)
self.ssm_states: torch.Tensor = torch.zeros(
    config.num_hidden_layers,
    self.max_batch_size,
    self.intermediate_size,
    self.ssm_state_size,
    device=device,
    dtype=dtype,
)

Collaborator:

why do we change the shape here?

@zucchini-nlp (Member, Author) commented on Jan 21, 2025:

So we can move tensors from "meta" to "cuda" layer by layer, whenever the cache is updated for a given layer. Otherwise we could move the whole 5D cache at once when the first layer is updated, but I just wanted to be consistent with the other static cache classes.
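
A minimal sketch of the per-layer layout being described, assuming illustrative shapes and a made-up materialize helper (not the actual mamba cache implementation):

```python
import torch

# One tensor per layer instead of a single stacked 5D tensor, so each layer's state
# can be moved off the meta device independently when that layer is first updated.
num_layers, batch, inter, kernel, ssm = 4, 1, 8, 4, 16
conv_states = [torch.zeros(batch, inter, kernel, device="meta") for _ in range(num_layers)]
ssm_states = [torch.zeros(batch, inter, ssm, device="meta") for _ in range(num_layers)]

def materialize(states, layer_idx, like):
    # Replace the meta placeholder for this layer with a real tensor on the device
    # and dtype of the incoming hidden states.
    if states[layer_idx].device.type == "meta":
        states[layer_idx] = torch.zeros(
            states[layer_idx].size(), device=like.device, dtype=like.dtype
        )
    return states[layer_idx]

# Example: layer 0 is materialized on CPU when CPU hidden states arrive.
materialize(conv_states, 0, torch.randn(batch, inter, kernel))
```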

Collaborator:

OK, sounds good; in that case the devices can indeed be different!

@zucchini-nlp merged commit 373e50e into huggingface:main on Jan 22, 2025
25 checks passed
bursteratom pushed a commit to bursteratom/transformers that referenced this pull request Jan 31, 2025
* init cache on meta device

* offloaded static + enable tests

* tests weren't running before  :(

* update

* fix mamba

* fix copies

* update

* address comments and fix tests

* fix copies

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update

* mamba fix

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
elvircrn pushed a commit to elvircrn/transformers that referenced this pull request Feb 13, 2025
Successfully merging this pull request may close this issue: Multi-GPU setup: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)