Wrap _prepare_4d_causal_attention_mask as a leaf function #27236
Conversation
Change LGTM - thanks for adding!
Happy to merge with @younesbelkada's approval.
cc @patrickvonplaten for reference
The changes look good! My only question is that this seems to wrap _prepare_4d_causal_attention_mask all the time, as long as fx is available. Is it possible to wrap it only when users actually perform model tracing?
No, it needs to happen at the top level of the module.
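For context, a minimal sketch of that constraint (a generic example, not this PR's code; `my_helper` is a hypothetical name): `torch.fx.wrap` inspects the calling frame and only accepts calls made at the top level of a module, so the registration cannot be deferred until a user starts tracing.

```python
import torch.fx


def my_helper(x):
    # Hypothetical helper standing in for _prepare_4d_causal_attention_mask.
    return x * 2


# OK: executed at import time, at the top level of this module.
torch.fx.wrap("my_helper")


def register_lazily():
    # In current PyTorch this raises NotImplementedError, because torch.fx.wrap
    # checks the calling frame and rejects anything that is not module-level.
    torch.fx.wrap("my_helper")
```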
LGTM. Reading the docs of the method, it just registers the function as a "leaf function" without modifying anything, so it seems safe to me:
https://pytorch.org/docs/stable/_modules/torch/fx/_symbolic_trace.html#wrap
It is. The only difference it implies is that it will not be possible to edit this function via torch.fx.
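As a quick illustration of what registering a leaf function means for tracing (a minimal sketch with a made-up `make_mask` helper standing in for `_prepare_4d_causal_attention_mask`, not code from this PR):

```python
import torch
import torch.fx


@torch.fx.wrap  # registers make_mask as a leaf; the decorator runs at module top level
def make_mask(x):
    # The body of a leaf function is not traced into, so shape-dependent logic
    # like this stays as plain Python executed at run time.
    seq_len = x.shape[-1]
    return torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)


class Toy(torch.nn.Module):
    def forward(self, x):
        return x + make_mask(x)


traced = torch.fx.symbolic_trace(Toy())
print(traced.graph)
# make_mask appears as a single call_function node, so graph transformations see
# an opaque call rather than the individual mask-building ops.
```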
The documentation is not available anymore as the PR was closed or merged.
Actually I have a workaround that works for my purposes in …
I think it's fine to leave as-is - people might want to use torch tracing outside of …
Wrap _prepare_4d_causal_attention_mask as a leaf function (huggingface#27236)
What does this PR do?
This wraps _prepare_4d_causal_attention_mask as an FX leaf function, for reasons similar to those here. The only consequence is that it will not be possible to edit this function by using torch.fx. It is not a big deal at all, but I will remove this constraint as soon as possible.
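For reference, a rough sketch of what the module-level wrapping amounts to (an illustration under assumptions, not the literal diff of this PR: `is_torch_fx_available` and `_prepare_4d_causal_attention_mask` are existing transformers utilities, but the guard and placement shown here are a reconstruction):

```python
# Sketch only: wrapping the helper at the top of a modeling file.
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.utils import is_torch_fx_available

if is_torch_fx_available():
    import torch.fx

    # Must run at module top level: torch.fx.wrap records the caller's globals so
    # that symbolic tracing later treats this name as a leaf call instead of
    # tracing through the mask-building logic.
    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
```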