
Llama: make slow tests green 🟢 (huggingface#33138)
gante authored and BernardZach committed Dec 5, 2024
1 parent 4899a45 commit 65c7162
Showing 31 changed files with 39 additions and 180 deletions.
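In short, the diff swaps the repeated inline TorchDynamo probe (`hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()`) for the `is_torchdynamo_compiling()` helper from `transformers.utils.import_utils`, rewraps overlong comments, and deletes a stale torch==2.2.0 TODO about cudagraph re-capture from `_update_causal_mask` in each decoder model that carried it.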
42 changes: 20 additions & 22 deletions src/transformers/modeling_attn_mask_utils.py
@@ -16,6 +16,8 @@

import torch

+from .utils.import_utils import is_torchdynamo_compiling
+

@dataclass
class AttentionMaskConverter:
@@ -243,30 +245,33 @@ def _ignore_causal_mask_sdpa(
is_training: bool = False,
) -> bool:
"""
-Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
+Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
+ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
`key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
-allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
+passed).
"""

_, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
key_value_length = query_length + past_key_values_length

-is_tracing = (
-    torch.jit.is_tracing()
-    or isinstance(inputs_embeds, torch.fx.Proxy)
-    or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-)
+is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()

ignore_causal_mask = False

if attention_mask is None:
-# TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
-# or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
+# TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
+# shape, thus SDPA's `is_causal` argument is rightfully updated
+# (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
+# `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
+# hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
+# which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
# Thus, we only set `ignore_causal_mask = True` if the model is set to training.
#
-# Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
+# Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
+# ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
if (
(is_training or not is_tracing)
and (query_length == 1 or key_value_length == query_length)
@@ -281,8 +286,9 @@ def _ignore_causal_mask_sdpa(
# For query_length == 1, causal attention and bi-directional attention are the same.
ignore_causal_mask = True

-# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
-# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
+# the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
+# SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
# TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
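Aside: the dispatch described above can be shown directly against PyTorch's SDPA. A minimal sketch (shapes and names here are illustrative, not from this commit):

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 8, 16, 64)  # [batch, heads, seq_len, head_dim]
    # Nothing masked and q_len == kv_len: pass attn_mask=None with is_causal=True,
    # which leaves SDPA free to dispatch to the flash attention kernel.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
    # Materializing a custom boolean mask instead (True = attend) blocks that dispatch:
    causal = torch.ones(16, 16, dtype=torch.bool).tril()
    out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal, is_causal=False)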

@@ -363,11 +369,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-is_tracing = (
-    torch.jit.is_tracing()
-    or isinstance(inputs_embeds, torch.fx.Proxy)
-    or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-)
+is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()

ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask=attention_mask,
@@ -439,11 +441,7 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
_, key_value_length = mask.shape
tgt_len = tgt_len if tgt_len is not None else key_value_length

-is_tracing = (
-    torch.jit.is_tracing()
-    or isinstance(mask, torch.fx.Proxy)
-    or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-)
+is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()

# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
if not is_tracing and torch.all(mask == 1):
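Note: the three hunks above replace the same inlined probe with one helper. A plausible sketch of `is_torchdynamo_compiling`, assuming it simply wraps the Dynamo check the deleted lines performed (the actual implementation in `transformers/utils/import_utils.py` may differ, e.g. by preferring the public `torch.compiler.is_compiling()` on newer PyTorch):

    import torch

    def is_torchdynamo_compiling() -> bool:
        # Guarded probe: older torch builds may lack torch._dynamo entirely.
        try:
            return bool(hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
        except Exception:
            return False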
5 changes: 0 additions & 5 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -790,11 +790,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
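The deleted TODO (repeated in each file below) names `@torch.compiler.disable` as a workaround for the cudagraph re-capture it describes. For reference, a hedged sketch of how that decorator is applied; the function here is hypothetical, and `torch.compiler.disable` requires torch >= 2.1:

    import torch

    @torch.compiler.disable  # excluded from compilation: no re-capture, but fullgraph=True becomes impossible
    def build_mask(attention_mask: torch.Tensor) -> torch.Tensor:
        # hypothetical dynamic-shape mask construction
        return attention_mask[:, None, None, :]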
5 changes: 0 additions & 5 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1433,11 +1433,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
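A small illustration of the `0.0 in attention_mask` test these flash-attention branches use; on a tensor, Python's `in` lowers to an equality-plus-any check:

    import torch

    mask = torch.tensor([[1.0, 1.0, 0.0]])   # last position padded
    print(0.0 in mask)                       # True
    print(bool((mask == 0.0).any()))         # the equivalent explicit form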
5 changes: 0 additions & 5 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -632,11 +632,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
5 changes: 0 additions & 5 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -912,11 +912,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
5 changes: 0 additions & 5 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -1163,11 +1163,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
5 changes: 0 additions & 5 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -931,11 +931,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
5 changes: 0 additions & 5 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -846,11 +846,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
5 changes: 0 additions & 5 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -1017,11 +1017,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
5 changes: 0 additions & 5 deletions src/transformers/models/gptj/modeling_gptj.py
@@ -941,11 +941,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
5 changes: 0 additions & 5 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -1474,11 +1474,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
5 changes: 0 additions & 5 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1136,11 +1136,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
5 changes: 0 additions & 5 deletions src/transformers/models/llama/modeling_llama.py
@@ -1038,11 +1038,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
7 changes: 2 additions & 5 deletions src/transformers/models/mask2former/modeling_mask2former.py
@@ -37,6 +37,7 @@
from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
from ...utils import is_accelerate_available, logging
from ...utils.backbone_utils import load_backbone
+from ...utils.import_utils import is_torchdynamo_compiling
from .configuration_mask2former import Mask2FormerConfig


@@ -1999,11 +2000,7 @@ def __init__(self, hidden_size: int, num_heads: int, mask_feature_size: torch.Te
def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None):
mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))

-is_tracing = (
-    torch.jit.is_tracing()
-    or isinstance(outputs, torch.fx.Proxy)
-    or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-)
+is_tracing = torch.jit.is_tracing() or isinstance(outputs, torch.fx.Proxy) or is_torchdynamo_compiling()
# Sum up over the channels
if is_tracing and not is_torch_greater_or_equal_than_2_1:
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
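The jit-friendly branch referenced above avoids `einsum` on older torch. One equivalent formulation, sketched with assumed shapes (mask_embeddings [b, q, c], pixel_embeddings [b, c, h, w]); the repository's own fallback may be written differently:

    import torch

    b, q, c, h, w = 2, 3, 4, 5, 6
    mask_embeddings = torch.randn(b, q, c)
    pixel_embeddings = torch.randn(b, c, h, w)
    reference = torch.einsum("bqc,bchw->bqhw", mask_embeddings, pixel_embeddings)
    # jit-friendly: batched matmul over flattened spatial dims, then reshape
    jit_friendly = torch.bmm(mask_embeddings, pixel_embeddings.flatten(2)).view(b, q, h, w)
    assert torch.allclose(reference, jit_friendly, atol=1e-6)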
7 changes: 2 additions & 5 deletions src/transformers/models/maskformer/modeling_maskformer.py
@@ -39,6 +39,7 @@
requires_backends,
)
from ...utils.backbone_utils import load_backbone
+from ...utils.import_utils import is_torchdynamo_compiling
from ..detr import DetrConfig
from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig
@@ -1680,11 +1681,7 @@ def get_logits(self, outputs: MaskFormerModelOutput) -> Tuple[Tensor, Tensor, Di
# get the auxiliary predictions (one for each decoder's layer)
auxiliary_logits: List[str, Tensor] = []

-is_tracing = (
-    torch.jit.is_tracing()
-    or isinstance(outputs, torch.fx.Proxy)
-    or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-)
+is_tracing = torch.jit.is_tracing() or isinstance(outputs, torch.fx.Proxy) or is_torchdynamo_compiling()
# This code is a little bit cumbersome, an improvement can be to return a list of predictions. If we have auxiliary loss then we are going to return more than one element in the list
if self.config.use_auxiliary_loss:
stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states)
5 changes: 0 additions & 5 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -852,11 +852,6 @@ def _update_causal_mask(
use_cache: bool,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
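The right-padding probe in the Mistral branch above flags any sequence whose final position is masked (flash-attention 2 expects left padding during cached generation). Illustration:

    import torch

    attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # second row is right-padded
    is_padding_right = attention_mask[:, -1].sum().item() != attention_mask.size(0)
    print(is_padding_right)  # True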
5 changes: 0 additions & 5 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -1121,11 +1121,6 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask