Llama: make slow tests green 🟢 #33138
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@@ -243,30 +245,33 @@ def _ignore_causal_mask_sdpa(
    is_training: bool = False,
) -> bool:
    """
    Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
    Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
This function had many >120 char lines, which did not fit on my screen :D The diff in docstrings/comments is exclusively due to breaking lines.
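For background, a minimal sketch of the two SDPA call styles that docstring contrasts: passing an explicit causal mask versus relying on `is_causal`. Shapes and values here are illustrative, not from the PR:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Variant 1: an explicit causal mask is materialized and passed as `attn_mask`.
causal_mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# Variant 2: the mask is ignored and SDPA's own `is_causal` flag is used instead,
# which can dispatch to more efficient fused kernels.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out_masked, out_causal, atol=1e-5))
```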
is_tracing = (
    torch.jit.is_tracing()
    or isinstance(inputs_embeds, torch.fx.Proxy)
    or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
`is_torchdynamo_compiling()` is a more comprehensive check. I've replaced all occurrences of this pattern.
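For reference, a minimal sketch of what the check looks like after that replacement, assuming `is_torchdynamo_compiling` is imported from `transformers.utils`; `inputs_embeds` stands in for the argument of the surrounding function shown in the diff:

```python
import torch
from transformers.utils import is_torchdynamo_compiling

inputs_embeds = torch.zeros(1, 4, 8)  # stand-in for the surrounding function's argument

# The `hasattr(torch, "_dynamo")` / `torch._dynamo.is_compiling()` pattern is folded
# into the helper, which also covers newer torch APIs.
is_tracing = (
    torch.jit.is_tracing()
    or isinstance(inputs_embeds, torch.fx.Proxy)
    or is_torchdynamo_compiling()
)
print(is_tracing)
```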
Does `is_torchdynamo_compiling` take into account the version of torch?
Yes!
It uses a try/except: if the newer functions are not available, it falls back to `torch._dynamo.is_compiling()` (and, if that is not available either, it means torch can't compile).
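A rough sketch of that fallback chain (illustrative, not the exact `transformers` implementation):

```python
def is_torchdynamo_compiling():
    # Prefer the newer public API when it exists, otherwise fall back to the
    # private dynamo API; if neither is available, torch cannot be compiling.
    try:
        import torch

        return torch.compiler.is_compiling()
    except Exception:
        try:
            import torch._dynamo as dynamo

            return dynamo.is_compiling()
        except Exception:
            return False


print(is_torchdynamo_compiling())  # False outside of a compiled region
```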
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
No longer true with #32227
Removed all occurrences of this comment.
# compiled static cache (removes the cache initialized in the previous check, to confirm we can
# initialize the cache in full compiled mode)
model._cache = None
# end-to-end compiled dynamic cache
The reference for end-to-end compilation is with the dynamic cache -- I forgot to update this test before the last set of changes in the PR that introduced end-to-end compilation :)
This includes recompiles, no? Or are dynamic shapes handled now?
Yes, it issues a recompile (but the static cache doesn't work at the moment -- both things, recompilations and using static caches, need to be addressed).
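For context, a minimal sketch of what end-to-end compiled generation with the (default) dynamic cache can look like; the checkpoint, device, and generation settings are illustrative and not taken from this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# Compiling `generate` end to end. With the dynamic cache, the growing sequence
# length changes tensor shapes between steps, so torch.compile issues recompiles
# (the behavior discussed above); avoiding that with a static cache is the part
# that still needs to be addressed.
compiled_generate = torch.compile(model.generate)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
out = compiled_generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```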
@@ -779,8 +781,8 @@ def test_model_7b_logits_bf16(self):
    torch.allclose(
        EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
        out.logits[0, 0, :15],
        atol=1e-3,
        rtol=1e-3,
        atol=1e-2,
We have expected results built for the GPUs in our CI, T4 and A10. However, if I navigate to the original commit and install the torch version at the time (torch 2.3), the test fails on my RTX4090.
This means there are probably tiny device-related differences whose explanation is beyond the cuda compute capability major version.
As such, I increased the tolerance. To be fair, 1e-2 is within the expected differences for 16-bit computations, just like 1e-5 is for 32-bit computations.
thanks
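As a side note, a rough illustration (not from the PR) of why differences around 1e-2 are plausible for 16-bit results while 1e-5 is the usual scale for 32-bit ones:

```python
import torch

torch.manual_seed(0)
x64 = torch.randn(1024, dtype=torch.float64)

# Round-tripping through bfloat16 keeps only ~8 bits of mantissa, so for values of
# order 1 the absolute error typically lands in the 1e-3 to 1e-2 range.
bf16_err = (x64.to(torch.bfloat16).double() - x64).abs().max().item()
print(f"max bfloat16 round-trip error: {bf16_err:.1e}")

# The same round trip through float32 gives errors around 1e-7, so 1e-5 is a
# comfortable tolerance for 32-bit computations.
fp32_err = (x64.to(torch.float32).double() - x64).abs().max().item()
print(f"max float32 round-trip error: {fp32_err:.1e}")
```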
What does this PR do?
Part 1 of #32685 -- update our tests to be sure we don't break things 🤗
Makes slow `llama` tests happy on my local environment. Some tests are still failing on our slow CI, mostly due to hardware (e.g. out of memory), but they do not have an impact on #32685 [in other words, on checking the correctness of new changes]. The exception is `test_compile_static_cache`, which passes when run in isolation but fails when run with other tests due to `accelerate` + `torch.compile` incompatibilities (see our internal discussion here).
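For anyone trying to reproduce the isolation behavior, a hypothetical way to run only that test from Python; the `RUN_SLOW` flag and the test file path are assumptions based on the usual `transformers` layout:

```python
import os

import pytest

# Slow tests are skipped unless this flag is set.
os.environ["RUN_SLOW"] = "1"

# Run only the static-cache compile test, isolated from the rest of the suite.
pytest.main(
    [
        "tests/models/llama/test_modeling_llama.py",
        "-k",
        "test_compile_static_cache",
        "-v",
    ]
)
```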