Re-enable SDPA's FA2 path #30070
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
is not failing locally, I don't get it. Besides, it is labeled as
is unrelated and also failing on main... Not sure why it appears here? Other than that, Cohere/Llama/Gemma slow tests pass.
I added some comments and suggested fixes to enable torch.compile with fullgraph=True and prevent applying bidirectional instead of causal attention. The suggestions also apply to the gemma and cohere models.
if attention_mask is None:
    # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export`
    # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
    # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
    #
    # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
    ignore_causal_mask = not is_tracing
@warner-benjamin I'm quite sure there is no way around this currently. We may revisit this with pytorch/pytorch#114823 if that helps.
But I get it that we should have a path where users don't use the attention_mask input and use torch.compile WITHOUT torch.export/torch.onnx.dynamo_export.
With my proposed modification to SdpaAttention.forward, I think it would be possible to check here for self.training, dispatching to FA2 when training and to Efficient in eval? Otherwise, a config option to force SDPA's FA2 backend would work.
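For illustration, here is a rough sketch of the kind of dispatch being proposed (a hypothetical, simplified module; not the actual SdpaAttention.forward):

import torch
import torch.nn.functional as F

class SdpaDispatchSketch(torch.nn.Module):
    # Hypothetical illustration of dispatching on self.training:
    # training -> no materialized mask, rely on is_causal so SDPA can pick the FA2 kernel;
    # eval -> pass the mask, which dispatches to the memory-efficient kernel.
    def forward(self, query, key, value, causal_mask, q_len):
        if self.training and causal_mask is None:
            return F.scaled_dot_product_attention(query, key, value, is_causal=q_len > 1)
        return F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)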
@warner-benjamin I think you misunderstood my message here.
When using torch.export or torch.onnx.dynamo_export, a single GraphModule is captured, no matter what, based on the sample input that is passed.
On the contrary, when simply using an OptimizedModule from torch.compile (even with fullgraph=True), full graph invalidation based on the input is still possible. You can check the torchdynamo logs (torch._logging.set_logs(dynamo=logging.INFO, aot=logging.INFO, inductor=logging.INFO, graph_breaks=True, guards=True, recompiles=True, output_code=True, graph_code=True, graph=True)), e.g. we get the following even with fullgraph=True:
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /home/felix/transformers/src/transformers/models/llama/modeling_llama.py:1163
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2024-04-08 11:58:14,174] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['input_ids']' size mismatch at index 1. expected 4, actual 1
& you can see the related FX ops, with is_causal hard-coded:
[2024-04-08 11:58:22,052] [0/1] torch._dynamo.output_graph.__graph: [DEBUG] call_function attn_output <built-in function scaled_dot_product_attention> (query_states_2, key_states_36, value_states_35) {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': False}
&
[2024-04-08 11:57:34,825] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function attn_output <built-in function scaled_dot_product_attention> (query_states_2, key_states_36, value_states_35) {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': True}
What I am saying is that we want to avoid hard-coding is_causal. As torch._dynamo.is_exporting() is not available (I wish it was, as well as torch._dynamo.is_fullgraph_compiling(), pytorch/pytorch#120400), my safe bet is to always use attn_mask in case torch.compile is used.
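To make the failure mode concrete, here is a minimal toy example (not the transformers code; requires a recent PyTorch with torch.export) showing how an exporter bakes in whichever is_causal branch the sample input happens to take:

import torch
import torch.nn.functional as F

class ToySelfAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # q_len is evaluated on the sample input during export, so the captured
        # graph hard-codes the resulting boolean for is_causal.
        q_len = q.shape[-2]
        return F.scaled_dot_product_attention(q, k, v, is_causal=q_len > 1)

q = torch.rand(1, 8, 4, 64)  # q_len == 4 here, so the exported graph gets is_causal=True
k = torch.rand(1, 8, 4, 64)
v = torch.rand(1, 8, 4, 64)

exported = torch.export.export(ToySelfAttention(), (q, k, v))
print(exported)  # the captured scaled_dot_product_attention call has is_causal=True baked in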
@fxmarty I've been unable to trigger the full graph invalidation as you've shown. Could you share code to reproduce it?
I've tried my proposed changes with torch.export and torch.compile with fullgraph=True, dynamic=True, and both fullgraph=True, dynamic=True, with PyTorch 2.2.2 and PyTorch 2.1.2 (minus export, since it's a 2.2 feature). I have been unsuccessful in triggering any recompilations (outside of different input shapes without dynamic=True, but that's to be expected).
> What I am saying is that we want to avoid hard-coding is_causal. As torch._dynamo.is_exporting() is not available (I wish it was, as well as torch._dynamo.is_fullgraph_compiling(), pytorch/pytorch#120400), my safe bet is to always use attn_mask in case torch.compile is used.
I'm not following why we cannot use the FA2 kernel during compiled training. We could dispatch via self.training to FA2/attn_mask, add a config option to force FA2/attn_mask, or both.
As it currently stands, someone testing model training in eager mode would use the memory-efficient FA2 kernel, and then when switching to torch.compile for additional memory savings and speed, their training would instead use more memory due to the kernel silently switching from FA2 to Efficient.
> @fxmarty I've been unable to trigger the full graph invalidation as you've shown. Could you share code to reproduce it?

Sure, let me do that next week.

> I'm not following why we cannot use the FA2 kernel during compiled training. We could dispatch via self.training to FA2/attn_mask, add a config option to force FA2/attn_mask, or both.

Sure, self.training could be a decent option.
@warner-benjamin using torch==2.2.2 and setting
if attention_mask is None:
    ignore_causal_mask = True
in _ignore_causal_mask_sdpa, and tracing with fullgraph=True, dynamic=True, I do get scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool when using is_causal=q_len > 1.
When tracing simply with fullgraph=True, I am getting the above (graph recompilation), which is not friendly with torch.export / ONNX dynamo export.
You can see the following script:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import copy
import logging

torch._logging.set_logs(dynamo=logging.INFO, aot=logging.INFO, inductor=logging.INFO, graph_breaks=True, guards=True, recompiles=True, output_code=True, graph_code=True, graph=True)

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )

model = model.eval()

inps = {
    "input_ids": torch.tensor([[5, 6, 7, 9]], dtype=torch.int64).to("cuda"),
    "use_cache": True,
}

# Prefill once to obtain a KV cache, then deep-copy it so every call below starts from the same state.
with torch.no_grad():
    out = model(**inps)

pkv = out.past_key_values
pkv_0 = copy.deepcopy(pkv)
pkv_1 = copy.deepcopy(pkv)
pkv_2 = copy.deepcopy(pkv)
pkv_3 = copy.deepcopy(pkv)

with torch.no_grad():
    print("--------- WITHOUT TORCH.COMPILE + input_ids > 1")
    inps_here = copy.deepcopy(inps)
    print("inps_here", inps_here.keys())
    res_several = model(**inps_here, past_key_values=pkv_0)

    print("--------- WITHOUT TORCH.COMPILE + input_ids == 1")
    inps_here = copy.deepcopy(inps)
    inps_here["input_ids"] = inps_here["input_ids"][:, :1]
    print("inps_here", inps_here.keys())
    res_one = model(**inps_here, past_key_values=pkv_1)

    print("compiling...")
    torch.cuda.synchronize()
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

    print("--------- WITH TORCH.COMPILE + input_ids > 1")
    inps_here = copy.deepcopy(inps)
    print("inps_here", inps_here.keys())
    res_several_comp = model(**inps_here, past_key_values=pkv_2)

    print("--------- WITH TORCH.COMPILE + input_ids == 1")
    inps_here = copy.deepcopy(inps)
    inps_here["input_ids"] = inps_here["input_ids"][:, :1]
    print("inps_here", inps_here.keys())
    res_one_comp = model(**inps_here, past_key_values=pkv_3)

# The forward pass returns a CausalLMOutputWithPast, so compare the logits tensors.
assert torch.allclose(res_several.logits, res_several_comp.logits)
assert torch.allclose(res_one.logits, res_one_comp.logits)
Thus in my opinion dropping the attention_mask by default in case we are using torch.compile is not the right choice for now.
We could have a force_causal_mask flag to inform users who want to make sure SDPA's FA2 backend is usable with torch.compile/torchscript/etc.; it would disable attn_mask altogether and force is_causal=True. Would that be fine with you? I would suggest doing it in another PR, though.
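For the record, a minimal sketch of what such an opt-in could look like (the force_causal_mask name and this helper are hypothetical, not an existing transformers option):

import torch.nn.functional as F

def sdpa_attention(query, key, value, attention_mask, q_len, force_causal_mask=False, dropout_p=0.0):
    # Hypothetical flag: skip the materialized mask entirely and hard-code is_causal=True,
    # keeping SDPA's FA2 backend usable under torch.compile / torchscript / export.
    if force_causal_mask:
        return F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=True)
    # Default path: pass the mask when present, otherwise fall back on the q_len condition.
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        dropout_p=dropout_p,
        is_causal=attention_mask is None and q_len > 1,
    )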
Thanks. Will fix compile+FA2 in a new PR once this one is merged.
Thank you for your thorough review!
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@ArthurZucker done. Also, running slow tests for llama/gemma/cohere gives failures which also occur on main on A100.
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from transformers.cache_utils import StaticCache
tokenizer = AutoTokenizer.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
with torch.device("cuda"):
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf",
torch_dtype=torch.float16,
attn_implementation="sdpa",
)
inputs = tokenizer(
["I would", "Today I am in Paris and", "I am"], padding=True, return_tensors="pt"
).to(model.device)
new_tokens = 10
gen_config = GenerationConfig(
max_new_tokens=new_tokens,
min_new_tokens=new_tokens,
use_cache=True,
pad_token_id=tokenizer.pad_token_id,
num_beams=1,
do_sample=False,
eos_token_id=None, # This is required for min_new_tokens to actually have an effect.
)
model.generation_config.eos_token_id = None # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.
print("----- GENERATE WITHOUT COMPILE")
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("decoded", decoded)
print("compiling...")
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
print("Finished compile call")
print("----- GENERATE WITH COMPILE")
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("decoded", decoded)
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("decoded", decoded) works as expected as well (on 2.3 RC, see pytorch/pytorch#121943 for 2.2.2, happening on main) |
I like that the black magic is hidden away with an "if magic: return None".
This reverts commit 05bdef1.
* tentatively re-enable FA2 + SDPA
* better comment
* _ignore_causal_mask_sdpa as staticmethod
* type hints
* use past_seen_tokens instead
* enable copied from for sdpa
* ruff
* llama simplifications on review
* remove unnecessary self.is_causal check
* fix copies
* cleaning
* precise message
* better doc
* add test
* simplify
* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_llama.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
As per title, SDPA can't dispatch to FA2 in case the attn_mask argument is not None, which was always the case in recent releases following the llama/gemma refactor and led to memory issues (#30010).

In case we compile and use fullgraph=True, the graph is recaptured for different q_len, so having is_causal=causal_mask is None and q_len > 1 is fine.

Reference: pytorch/pytorch#108108
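For reference, a simplified sketch of the call pattern this re-enables (names and signature are illustrative, condensed from the modeling code rather than copied from it):

import torch.nn.functional as F

def sdpa_forward(query, key, value, causal_mask, q_len, dropout_p=0.0):
    # When no mask is materialized and we are not decoding a single token,
    # attn_mask=None together with is_causal=True lets SDPA select the FA2 kernel.
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=causal_mask,
        dropout_p=dropout_p,
        is_causal=causal_mask is None and q_len > 1,
    )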