Re-enable SDPA's FA2 path #30070

Merged (22 commits) on Apr 17, 2024
Changes from 13 commits
53 changes: 53 additions & 0 deletions src/transformers/modeling_attn_mask_utils.py
@@ -234,6 +234,59 @@ def _unmask_unattended(

        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))

    @staticmethod
    def _ignore_causal_mask_sdpa(
        attention_mask: Optional[torch.Tensor],
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
    ) -> bool:
        """
        Detects whether the attention_mask can be ignored when PyTorch's SDPA is used.

        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
        `key_value_length == query_length`, we rather rely on SDPA's `is_causal` argument to use causal/non-causal masks,
        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
        """

        batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
        key_value_length = query_length + past_key_values_length

        is_tracing = (
            torch.jit.is_tracing()
            or isinstance(inputs_embeds, torch.fx.Proxy)
            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
        )

        ignore_causal_mask = False

        if attention_mask is None:
            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape,
            # thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363).
            # However, when using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input and the
            # `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code
            # `is_causal=True`, which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
            # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
            #
            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
            # (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
            ignore_causal_mask = not is_tracing
fxmarty (Contributor Author) commented:
@warner-benjamin I'm quite sure there is no way around this currently. We may revisit this with pytorch/pytorch#114823 if that helps.

But I agree that we should have a path for users who don't pass an attention_mask input and use torch.compile WITHOUT torch.export/torch.onnx.dynamo_export.

warner-benjamin (Contributor) commented:

With my proposed modification to SdpaAttention.forward, I think it would be possible to check `self.training` here to dispatch to FA2 during training and to the efficient kernel in eval. Otherwise, a config option to force SDPA's FA2 backend would work.
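For concreteness, here is a rough, hypothetical sketch of the dispatch being proposed (the helper name and the surrounding assumptions are illustrative, not code from this PR): during training there is typically no padding and no KV cache, so the mask could be dropped and `is_causal=True` used unconditionally, while the eval/tracing/export path keeps the explicit mask.

import torch
import torch.nn.functional as F


def sdpa_with_training_dispatch(module: torch.nn.Module, query, key, value, causal_mask, q_len: int):
    # Hypothetical: assumes training batches are unpadded and run without a KV cache,
    # so q_len > 1 always holds and is_causal=True is safe; this lets SDPA pick the FA2 kernel.
    if module.training and causal_mask is None:
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)
    # Eval / tracing / export path: keep the explicit mask, as the PR does.
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        is_causal=causal_mask is None and q_len > 1,
    )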

fxmarty (Contributor Author) commented on Apr 12, 2024:

@warner-benjamin I think you misunderstood my message here.

When using torch.export or torch.onnx.dynamo_export, a single GraphModule is captured no matter what, based on the sample input that is passed.

On the contrary, when simply using an OptimizedModule from torch.compile (even with fullgraph=True), full graph invalidation is still possible depending on the input. You can check the TorchDynamo logs with

torch._logging.set_logs(dynamo=logging.INFO, aot=logging.INFO, inductor=logging.INFO, graph_breaks=True, guards=True, recompiles=True, output_code=True, graph_code=True, graph=True)

e.g. we get the following even with fullgraph=True:

[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /home/felix/transformers/src/transformers/models/llama/modeling_llama.py:1163
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG]     triggered by the following guard failure(s):
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG]     - tensor 'L['input_ids']' size mismatch at index 1. expected 4, actual 1

and you can see the related FX ops, with `is_causal` hard-coded:

[2024-04-08 11:58:22,052] [0/1] torch._dynamo.output_graph.__graph: [DEBUG] call_function attn_output <built-in function scaled_dot_product_attention> (query_states_2, key_states_36, value_states_35) {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': False}

and

[2024-04-08 11:57:34,825] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function attn_output <built-in function scaled_dot_product_attention> (query_states_2, key_states_36, value_states_35) {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': True}

What I am saying is that we want to avoid hard-coding is_causal. As torch._dynamo.is_exporting() is not available (I wish it were, as well as torch._dynamo.is_fullgraph_compiling(), pytorch/pytorch#120400), my safe bet is to always use attn_mask in case torch.compile is used.

warner-benjamin (Contributor) commented on Apr 12, 2024:

@fxmarty I've been unable to trigger the full graph invalidation as you've shown. Could you share code to reproduce it?

I've tried my proposed changes with torch.export and with torch.compile using fullgraph=True, using dynamic=True, and using both together, on PyTorch 2.2.2 and PyTorch 2.1.2 (minus export, since it's a 2.2 feature). I have been unsuccessful in triggering any recompilations (outside of different input shapes without dynamic=True, but that's to be expected).

> What I am saying is that we want to avoid hard-coding is_causal. As torch._dynamo.is_exporting() is not available (I wish it were, as well as torch._dynamo.is_fullgraph_compiling(), pytorch/pytorch#120400), my safe bet is to always use attn_mask in case torch.compile is used.

I'm not following why we cannot use the FA2 kernel during compiled training. We could dispatch via self.training to FA2/attn_mask, add a config option to force FA2/attn_mask, or both.

As it currently stands, someone testing model training in eager mode would use the memory-efficient FA2 kernel, and then, when switching to torch.compile for additional memory savings and speed, their training would instead use more memory because the kernel silently switches from FA2 to the efficient backend.
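One way to check that memory claim on a given setup (a measurement sketch, not from this discussion; it assumes a CUDA GPU supported by FA2 and the PyTorch ~2.2 `torch.backends.cuda.sdp_kernel` context manager, which later releases deprecate in favor of `torch.nn.attention.sdpa_kernel`) is to compare peak memory of a forward + backward pass with only the flash backend enabled versus only the memory-efficient backend:

import torch
import torch.nn.functional as F


def peak_memory_backward(enable_flash: bool) -> int:
    # Measure peak CUDA memory for a causal SDPA forward + backward with a single backend enabled.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    q = torch.randn(4, 16, 2048, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    with torch.backends.cuda.sdp_kernel(
        enable_flash=enable_flash, enable_math=False, enable_mem_efficient=not enable_flash
    ):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out.sum().backward()
    return torch.cuda.max_memory_allocated()


print("flash:", peak_memory_backward(True), "mem_efficient:", peak_memory_backward(False))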

fxmarty (Contributor Author) commented on Apr 12, 2024:

> @fxmarty I've been unable to trigger the full graph invalidation as you've shown. Could you share code to reproduce it?

Sure, let me do that next week.

> I'm not following why we cannot use the FA2 kernel during compiled training. We could dispatch via self.training to FA2/attn_mask, add a config option to force FA2/attn_mask, or both.

Sure, self.training could be a decent option.

fxmarty (Contributor Author) commented on Apr 15, 2024:

@warner-benjamin using torch==2.2.2 and setting

        if attention_mask is None:
            ignore_causal_mask = True

in `_ignore_causal_mask_sdpa`, and tracing with fullgraph=True, dynamic=True, I do get `scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool` when using `is_causal=q_len > 1`.

When tracing simply with fullgraph=True, I am getting the above (graph recompilation), which is not friendly with torch.export / ONNX dynamo export.

You can see the following script:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import copy

import logging
torch._logging.set_logs(dynamo=logging.INFO, aot=logging.INFO, inductor=logging.INFO, graph_breaks=True, guards=True, recompiles=True, output_code=True, graph_code=True, graph=True)

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )

model = model.eval()
inps = {
    "input_ids": torch.tensor([[5, 6, 7, 9]], dtype=torch.int64).to("cuda"),
    "use_cache": True,
}

with torch.no_grad():
    out = model(**inps)

    pkv = out.past_key_values
    pkv_0 = copy.deepcopy(pkv)
    pkv_1 = copy.deepcopy(pkv)
    pkv_2 = copy.deepcopy(pkv)
    pkv_3 = copy.deepcopy(pkv)

with torch.no_grad():

    print("--------- WITHOUT TORCH.COMPILE + input_ids > 1")
    inps_here = copy.deepcopy(inps)
    print("inps_here", inps_here.keys())

    res_several = model(**inps_here, past_key_values=pkv_0)

    print("--------- WITHOUT TORCH.COMPILE + input_ids == 1")
    inps_here = copy.deepcopy(inps)
    inps_here["input_ids"] = inps_here["input_ids"][:, :1]
    print("inps_here", inps_here.keys())

    res_one = model(**inps_here, past_key_values=pkv_1)

    print("compiling...")

    torch.cuda.synchronize()

    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

    print("--------- WITH TORCH.COMPILE + input_ids > 1")
    inps_here = copy.deepcopy(inps)
    print("inps_here", inps_here.keys())

    res_several_comp = model(**inps_here, past_key_values=pkv_2)

    print("--------- WITH TORCH.COMPILE + input_ids == 1")
    inps_here = copy.deepcopy(inps)
    inps_here["input_ids"] = inps_here["input_ids"][:, :1]
    print("inps_here", inps_here.keys())

    res_one_comp = model(**inps_here, past_key_values=pkv_3)

assert torch.allclose(res_several.logits, res_several_comp.logits)
assert torch.allclose(res_one.logits, res_one_comp.logits)

Thus, in my opinion, dropping the attention_mask by default when torch.compile is used is not the right choice for now.

We could have a force_causal_mask flag for users who want to make sure SDPA's FA2 backend is usable with torch.compile/torchscript/etc.; it would disable attn_mask altogether and force is_causal=True. Would that be fine with you? I would suggest doing it in another PR, though.
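A minimal standalone sketch of what that escape hatch could look like (the force_causal_mask name comes from the comment above and is hypothetical; the helper below is a simplified stand-in for `_ignore_causal_mask_sdpa`, not the actual transformers code):

from typing import Optional

import torch


def ignore_causal_mask_sdpa(
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    key_value_length: int,
    is_tracing: bool,
    force_causal_mask: bool = False,
) -> bool:
    # Hypothetical flag: the user asserts the inputs are unpadded and attention is purely causal,
    # so hard-coding is_causal=True in a traced/exported graph is always correct.
    if force_causal_mask:
        return True
    if attention_mask is None:
        # Mirror of the PR's behavior: only drop the mask when not tracing/exporting.
        return not is_tracing
    if not is_tracing and torch.all(attention_mask == 1):
        return query_length == 1 or key_value_length == query_length
    return False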

warner-benjamin (Contributor) commented:

Thanks. Will fix compile+FA2 in a new PR once this one is merged.

fxmarty (Contributor Author) commented:

Thank you for your thorough review!

        else:
            if len(attention_mask.shape) == 4:
                expected_shape = (batch_size, 1, query_length, key_value_length)
                if tuple(attention_mask.shape) != expected_shape:
                    raise ValueError(
                        f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
                    )
            elif not is_tracing and torch.all(attention_mask == 1):
                if query_length == 1:
                    # For query_length == 1, causal attention and bi-directional attention are the same.
                    ignore_causal_mask = True
                elif key_value_length == query_length:
                    ignore_causal_mask = True

                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
                # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in SDPA
                # and rely on Transformers attention_mask instead, hence not setting it to None here.
                # Reference: https://github.com/pytorch/pytorch/issues/108108
                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

        return ignore_causal_mask


def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
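As a sanity check on the two cases where the helper above returns True, the following small script (an illustration, not part of the PR) verifies with plain SDPA calls that an explicit additive causal mask matches `is_causal=True` when `query_length == key_value_length`, and that a single-query decode step needs no mask at all:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 5, 8)  # (batch, num_heads, query_length, head_dim)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)

# query_length == key_value_length: an explicit additive causal mask and `is_causal=True` agree,
# so the mask can be dropped and SDPA left to build it (which enables the flash attention kernel).
additive_causal = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_causal)
with_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(with_mask, with_flag, atol=1e-6)

# query_length == 1: causal and bi-directional attention coincide, since the single (last) query
# may attend to every key; no mask at all gives the same result as the last row of the causal prefill.
last_row_of_causal = with_flag[:, :, -1:, :]
decode_no_mask = F.scaled_dot_product_attention(q[:, :, -1:, :], k, v)
assert torch.allclose(last_row_of_causal, decode_no_mask, atol=1e-6)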
39 changes: 30 additions & 9 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -590,12 +590,15 @@ def forward(
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

+        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to
+        # SDPA's Flash Attention 2 backend, relying instead on the `is_causal` argument.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=causal_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
@@ -908,9 +911,7 @@ def forward(
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
-        )
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

        # embed positions
        hidden_states = inputs_embeds
@@ -974,24 +975,44 @@ def forward(
            attentions=all_self_attns,
        )

-    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes
-    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_seen_tokens: int,
+    ):
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes
+        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

+        ignore_causal_mask = False
+        if self.config._attn_implementation == "sdpa":
+            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+            # in order to dispatch to Flash Attention 2.
+            ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            )
+
+        if ignore_causal_mask:
+            return None
+
        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
            target_length = self.config.max_position_embeddings
        else:  # dynamic cache
            target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
            )

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
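To see why the new `is_causal=causal_mask is None and q_len > 1` path matters, here is a hedged, standalone check (not from the PR; it assumes PyTorch ~2.2 on a CUDA GPU that FlashAttention-2 supports, and uses `torch.backends.cuda.sdp_kernel`, which later releases deprecate in favor of `torch.nn.attention.sdpa_kernel`): with the non-flash backends disabled, SDPA succeeds when given `is_causal=True` and no mask, but cannot run when an explicit `attn_mask` is passed.

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
additive_causal = torch.triu(
    torch.full((128, 128), float("-inf"), device="cuda", dtype=torch.float16), diagonal=1
)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    # No attn_mask: SDPA is free to dispatch to the FlashAttention-2 kernel.
    F.scaled_dot_product_attention(q, k, v, is_causal=True)
    try:
        # An explicit attn_mask rules out the flash backend, so no enabled kernel is available here.
        F.scaled_dot_product_attention(q, k, v, attn_mask=additive_causal)
    except RuntimeError as err:
        print("flash backend unavailable with an explicit attn_mask:", err)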
39 changes: 30 additions & 9 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -570,12 +570,15 @@ def forward(
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

+        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to
+        # SDPA's Flash Attention 2 backend, relying instead on the `is_causal` argument.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=causal_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
@@ -888,9 +891,7 @@ def forward(
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
-        )
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

        # embed positions
        hidden_states = inputs_embeds
@@ -960,24 +961,44 @@ def forward(
            attentions=all_self_attns,
        )

-    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes
-    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_seen_tokens: int,
+    ):
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes
+        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

+        ignore_causal_mask = False
+        if self.config._attn_implementation == "sdpa":
+            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+            # in order to dispatch to Flash Attention 2.
+            ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            )
+
+        if ignore_causal_mask:
+            return None
+
        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
            target_length = self.config.max_position_embeddings
        else:  # dynamic cache
            target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
            )

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
40 changes: 30 additions & 10 deletions src/transformers/models/llama/modeling_llama.py
@@ -656,7 +656,6 @@ def forward(
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
-        # if attention_mask is not None and cache_position is not None:
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

@@ -667,12 +666,15 @@ def forward(
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

+        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to
+        # SDPA's Flash Attention 2 backend, relying instead on the `is_causal` argument.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=causal_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
@@ -987,9 +989,7 @@ def forward(
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
-        )
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

        # embed positions
        hidden_states = inputs_embeds
@@ -1053,24 +1053,44 @@ def forward(
            attentions=all_self_attns,
        )

-    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes
-    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_seen_tokens: int,
+    ):
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes
+        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

+        ignore_causal_mask = False
+        if self.config._attn_implementation == "sdpa":
+            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+            # in order to dispatch to Flash Attention 2.
+            ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            )
+
+        if ignore_causal_mask:
+            return None
+
        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
            target_length = self.config.max_position_embeddings
        else:  # dynamic cache
            target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
            )

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)