
Llama: make slow tests green 🟢 #33138

Merged · 3 commits into huggingface:main from fix_llama_tests · Aug 27, 2024
Conversation

@gante (Member) commented Aug 27, 2024

What does this PR do?

Part 1 of #32685 -- update our tests to be sure we don't break things 🤗

Makes the slow Llama tests happy in my local environment. Some tests are still failing on our slow CI, mostly due to hardware (e.g. out of memory), but they have no impact on #32685 (in other words, on checking the correctness of the new changes).

The exception is test_compile_static_cache, which passes when run in isolation but fails if run with other tests due to accelerate + torch.compile incompatibilities (see our internal discussion here).

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante marked this pull request as ready for review August 27, 2024 10:39
@@ -243,30 +245,33 @@ def _ignore_causal_mask_sdpa(
     is_training: bool = False,
 ) -> bool:
     """
-    Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
+    Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
+    ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
gante (Member, Author):
This function had many >120 char lines, which did not fit on my screen :D The diff in docstrings/comments is exclusively due to breaking lines.
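
For context on what the docstring refers to: when the mask can be ignored, the modeling code relies on SDPA's `is_causal` argument instead of materializing an explicit mask. A rough, self-contained illustration of the two equivalent paths (shapes and tensors below are arbitrary, not taken from the modeling code):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Path 1: materialize an explicit boolean causal mask (True = "may attend").
causal_mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))
out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# Path 2: pass no mask at all and let the SDPA kernel apply causality itself.
out_is_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Both paths compute the same attention (up to backend-dependent numerics).
print(torch.allclose(out_explicit, out_is_causal, atol=1e-5))
```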

is_tracing = (
torch.jit.is_tracing()
or isinstance(inputs_embeds, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
gante (Member, Author):
is_torchdynamo_compiling() is a more comprehensive check.

I've replaced all occurrences of this pattern.

ArthurZucker (Collaborator):

Does is_torchdynamo_compiling() take into account the version of torch?

gante (Member, Author):
Yes!

It uses a try/except: if the newer functions are not available, it falls back to torch._dynamo.is_compiling() (and, if that is not available either, it means torch can't compile).
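
A minimal sketch of that fallback pattern (illustrative only, not the exact transformers implementation):

```python
import torch


def is_torchdynamo_compiling_sketch() -> bool:
    """Illustrative version-aware check, mirroring the try/except idea above."""
    try:
        # Newer torch releases expose a public API for this check.
        return torch.compiler.is_compiling()
    except AttributeError:
        try:
            # Older releases only have the private dynamo helper.
            return torch._dynamo.is_compiling()
        except AttributeError:
            # Neither exists: this torch build cannot compile at all.
            return False
```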

Comment on lines -793 to -797
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

gante (Member, Author):

No longer true with #32227

Removed all occurrences of this comment.

# compiled static cache (removes the cache initialized in the previous check, to confirm we can
# initialize the cache in full compiled mode)
model._cache = None
# end-to-end compiled dynamic cache
gante (Member, Author):
The reference for end-to-end compilation is with the dynamic cache -- I forgot to update this test before the last set of changes in the PR that introduced end-to-end compilation :)

ArthurZucker (Collaborator):

This includes a recompile, no? Or are dynamic shapes handled now?

gante (Member, Author):

Yes, it issues a recompile

(but static cache doesn't work at the moment -- both things, recompilations and using static caches, need to be addressed)
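
For readers unfamiliar with the setup, here is a hedged sketch of what "end-to-end compiled generation with a dynamic cache" looks like. The checkpoint name, prompt, and generation flags are placeholders, not taken from the test:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint -- the slow tests target Llama models, but any causal LM
# illustrates the pattern. A GPU is not strictly required, just much faster.
model_id = "meta-llama/Llama-2-7b-hf"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

# End-to-end compilation: wrap generate() itself, keeping the default dynamic cache.
# As discussed above, new input shapes trigger recompiles at this point.
compiled_generate = torch.compile(model.generate)

inputs = tokenizer("Simply put, the theory of relativity states that", return_tensors="pt").to(device)
output_ids = compiled_generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```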

@@ -779,8 +781,8 @@ def test_model_7b_logits_bf16(self):
         torch.allclose(
             EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
             out.logits[0, 0, :15],
-            atol=1e-3,
-            rtol=1e-3,
+            atol=1e-2,
gante (Member, Author):

We have expected results built for the GPUs in our CI, T4 and A10. However, if I check out the original commit and install the torch version at the time (torch 2.3), the test fails on my RTX 4090.

This means there are probably tiny device-related differences whose explanation goes beyond the cuda compute capability major version.

As such, I increased the tolerance. To be fair, 1e-2 is within the expected differences for 16-bit computations, just like 1e-5 is for 32-bit computations.
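
As a rough illustration of that tolerance argument (the numbers below are made up, not the test's expected logits):

```python
import torch

# bf16 keeps roughly 3 significant decimal digits, so per-element deviations on
# the order of a few 1e-3 across devices are normal for logits of this magnitude.
expected = torch.tensor([-6.65, -0.40, -0.46], dtype=torch.bfloat16)
observed = expected + torch.tensor([3e-3, -4e-3, 5e-3], dtype=torch.bfloat16)

print(torch.allclose(expected, observed, atol=1e-3, rtol=1e-3))  # False: the old tolerance is too tight
print(torch.allclose(expected, observed, atol=1e-2, rtol=1e-2))  # True: the relaxed tolerance absorbs the drift
```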

@gante requested a review from ArthurZucker August 27, 2024 10:57
@ArthurZucker (Collaborator) left a comment:

thanks


@gante merged commit c6b23fd into huggingface:main Aug 27, 2024
22 checks passed
@gante deleted the fix_llama_tests branch August 27, 2024 13:44
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024