Re-enable SDPA's FA2 path #30070
@warner-benjamin I'm quite sure there is no way around this currently. We may revisit this with pytorch/pytorch#114823 if that helps.

But I get it that we should have a path where users don't use the `attention_mask` input, and use torch.compile WITHOUT `torch.export` / `torch.onnx.dynamo_export`.
With my proposed modification to `SdpaAttention.forward`, I think it would be possible to check here for `self.training` to dispatch to FA2 during training and to Efficient in eval? Otherwise, a config option to force SDPA FA2 would work.
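A minimal sketch of the kind of `self.training` dispatch being discussed (illustrative only; the module, argument names, and dropout handling are placeholders, not the actual transformers code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SdpaSelfAttentionSketch(nn.Module):
    """Illustrative sketch: choose the SDPA call pattern based on self.training."""

    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.dropout = dropout

    def forward(self, query, key, value, causal_mask=None):
        if self.training:
            # Drop the explicit mask and express causality via is_causal=True so
            # the SDPA dispatcher is free to pick the flash-attention (FA2) backend.
            return F.scaled_dot_product_attention(
                query, key, value, attn_mask=None,
                dropout_p=self.dropout, is_causal=True,
            )
        # Eval / generation: keep the explicit mask (padding, q_len == 1 decoding),
        # which typically lands on the memory-efficient backend instead of FA2.
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=causal_mask,
            dropout_p=0.0, is_causal=False,
        )
```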
@warner-benjamin I think you misunderstood my message here.

When using `torch.export` / `torch.onnx.dynamo_export`, a single GraphModule is captured, no matter what, based on the sample input that is passed.

On the contrary, simply when using an `OptimizedModule` from torch.compile (even with fullgraph=True), full graph invalidation based on the input is still possible. You can check the torchdynamo logs (`torch._logging.set_logs(dynamo=logging.INFO, aot=logging.INFO, inductor=logging.INFO, graph_breaks=True, guards=True, recompiles=True, output_code=True, graph_code=True, graph=True)`): we get recompilations even with fullgraph=True, and you can see the related FX ops with `is_causal` hard-coded:

[2024-04-08 11:58:22,052] [0/1] torch._dynamo.output_graph.__graph: [DEBUG] call_function attn_output <built-in function scaled_dot_product_attention> (query_states_2, key_states_36, value_states_35) {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': False}

and

[2024-04-08 11:57:34,825] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function attn_output <built-in function scaled_dot_product_attention> (query_states_2, key_states_36, value_states_35) {'attn_mask': None, 'dropout_p': 0.0, 'is_causal': True}

What I am saying is that we want to avoid hard-coding `is_causal`. As `torch._dynamo.is_exporting()` is not available (I wish it was, as well as `torch._dynamo.is_fullgraph_compiling()`, pytorch/pytorch#120400), my safe bet is to always use `attn_mask` in case torch.compile is used.
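For reference, a minimal sketch of how this can be observed (a standalone function with placeholder shapes rather than a full model; the logging call is the one quoted above):

```python
import logging
import torch
import torch.nn.functional as F

# Enable the torchdynamo/inductor logs mentioned above (guards, recompiles, graphs).
torch._logging.set_logs(
    dynamo=logging.INFO, aot=logging.INFO, inductor=logging.INFO,
    graph_breaks=True, guards=True, recompiles=True,
    output_code=True, graph_code=True, graph=True,
)

def sdpa(q, k, v):
    # is_causal is computed from the input shape, so each outcome of `q_len > 1`
    # is baked into the captured graph as a constant (True in one graph, False
    # in another), guarded on the input.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=q.shape[2] > 1
    )

compiled = torch.compile(sdpa, fullgraph=True)

k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
compiled(torch.randn(1, 8, 16, 64), k, v)  # graph captured with is_causal=True
compiled(torch.randn(1, 8, 1, 64), k, v)   # new graph with is_causal=False (recompile)
```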
@fxmarty I've been unable to trigger the full graph invalidation as you've shown. Could you share code to reproduce it?

I've tried my proposed changes with `torch.export` and `torch.compile` with `fullgraph=True`, `dynamic=True`, and both `fullgraph=True, dynamic=True`, on PyTorch 2.2.2 and PyTorch 2.1.2 (minus export, since it's a 2.2 feature). I have been unsuccessful in triggering any recompilations (outside of different input shapes without `dynamic=True`, but that's to be expected).

I'm not following why we cannot use the FA2 kernel during compiled training. We could dispatch via `self.training` to FA2/`attn_mask`, add a config option to force FA2/`attn_mask`, or both.

As it currently stands, someone testing model training in eager mode would use the memory-efficient FA2 kernel, and then, when switching to `torch.compile` for additional memory savings and speed, their training would instead use more memory due to the kernel silently switching from FA2 to efficient.
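For context, one way to see why the kernel choice matters is to restrict SDPA to the flash backend and compare the two call patterns (a sketch, assuming a CUDA GPU that supports flash attention; shapes are placeholders; this uses the torch 2.2-era `torch.backends.cuda.sdp_kernel` context manager, which newer versions replace with `torch.nn.attention.sdpa_kernel`):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
# Explicit lower-triangular (causal) boolean mask.
causal_mask = torch.ones(128, 128, device="cuda", dtype=torch.bool).tril()

# Only allow the flash-attention (FA2) backend.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    # Mask-free causal call: eligible for FA2.
    F.scaled_dot_product_attention(q, k, v, is_causal=True)
    try:
        # The same computation expressed through an explicit attn_mask: the flash
        # backend cannot take an arbitrary mask, so flash-only dispatch fails.
        F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
    except RuntimeError as err:
        print("explicit attn_mask is not FA2-eligible:", err)
```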
Sure, let me do that next week.

Sure, `self.training` could be a decent option.
@warner-benjamin using torch==2.2.2 and setting … in `_ignore_causal_mask_sdpa`, and tracing with `fullgraph=True, dynamic=True`, I do get `scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool` when using `is_causal=q_len > 1`.

When tracing simply with `fullgraph=True`, I am getting the above (graph recompilation), which is not friendly with `torch.export` / ONNX dynamo export.

You can see the following script:
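A minimal sketch of the failing pattern described above (not the original script; the function, shapes, and `mark_dynamic` usage are illustrative assumptions, and whether the error triggers can depend on how the shape ends up being treated as dynamic):

```python
import torch
import torch._dynamo
import torch.nn.functional as F

def sdpa(q, k, v):
    q_len = q.shape[2]
    # Under dynamic-shape tracing, q_len is a SymInt, so `q_len > 1` is a SymBool;
    # scaled_dot_product_attention only accepts a plain Python bool for is_causal.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=q_len > 1
    )

compiled = torch.compile(sdpa, fullgraph=True, dynamic=True)

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
# Mark the sequence dimension as dynamic so the comparison stays symbolic.
torch._dynamo.mark_dynamic(q, 2)
compiled(q, k, v)  # expected to raise: argument 'is_causal' must be bool, not SymBool
```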
Thus, in my opinion, dropping the `attention_mask` by default in case we are using torch.compile is not the right choice for now.

We could have a `force_causal_mask` flag to inform users who want to make sure the SDPA FA2 backend is usable with torch.compile/torchscript/etc.; it would disable `attn_mask` altogether and force `is_causal=True`. Would that be fine with you? I would suggest doing it in another PR though.
Thanks. Will fix compile+FA2 in a new PR once this one is merged.
Thank you for your thorough review!