[generate + static cache + torch.compile] ability to pass statically shaped 4D attention_mask to the model forward #29165
Comments
cc @gante

@fxmarty I had a quick look at this -- we still have models (like ...). We may want to work with padded tensors instead? In our TF/PT XLA implementation we have static shapes everywhere.

@gante thank you! What do you mean by working with padded tensors?

@fxmarty If we want to generate with ...

Oh yes, I think that's what I am suggesting here!
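The "padded tensors" idea discussed above could look roughly like the sketch below. This is only an illustration: the `pad_to_static_length` helper, its signature, and the chosen `max_length` are assumptions, not an existing transformers API. The point is that inputs are padded once to a fixed maximum length, so every tensor handed to a compiled model keeps a static shape.

```python
import torch
import torch.nn.functional as F

def pad_to_static_length(input_ids, attention_mask, max_length, pad_token_id):
    """Right-pad (batch, seq_len) tensors to (batch, max_length) so shapes stay static."""
    pad_len = max_length - input_ids.shape[1]
    input_ids = F.pad(input_ids, (0, pad_len), value=pad_token_id)
    attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
    return input_ids, attention_mask

# Usage with dummy tensors: shapes are (2, 16) regardless of the real prompt length,
# so a compiled forward pass sees the same input shapes on every call.
ids = torch.randint(0, 100, (2, 5))
mask = torch.ones(2, 5, dtype=torch.long)
ids, mask = pad_to_static_length(ids, mask, max_length=16, pad_token_id=0)
```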
Feature request
Currently, the `attention_mask` passed to the model is 2D and has a dynamic shape. This causes issues when using a compiled model with `model.forward = torch.compile(model.forward, mode="reduce-overhead")`; see pytorch/pytorch#120309 and #29114. Instead, we may want to pass 4D masks of static shape directly to the model, avoiding CUDA graph re-recording. This could bring the compile time down from more than 3 minutes to roughly 50 seconds.
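To illustrate the idea, here is a minimal sketch of how a statically shaped 4D additive mask could be built from a fixed-length 2D padding mask and then fed to a compiled forward pass. The `build_static_4d_mask` helper, the fixed `max_seq_len`, and the commented-out `model(...)` call are assumptions for illustration, not the actual transformers implementation.

```python
import torch

batch_size, max_seq_len = 1, 512  # hypothetical static sizes
min_value = torch.finfo(torch.float16).min

def build_static_4d_mask(padding_mask_2d: torch.Tensor) -> torch.Tensor:
    """Expand a 2D padding mask (batch, max_seq_len) into a causal 4D additive
    mask of fixed shape (batch, 1, max_seq_len, max_seq_len)."""
    causal = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
    # Positions that are masked out (future tokens or padding) get a large negative value.
    allowed = causal[None, None, :, :] & padding_mask_2d[:, None, None, :].bool()
    mask_4d = torch.zeros(batch_size, 1, max_seq_len, max_seq_len, dtype=torch.float16)
    return mask_4d.masked_fill(~allowed, min_value)

# The 2D padding mask is padded up to max_seq_len once, so the resulting 4D mask
# keeps the same shape at every decoding step and the CUDA graph can be replayed
# instead of re-recorded.
padding_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.long)
padding_mask[:, :10] = 1  # e.g. 10 real tokens, the rest is padding
attention_mask_4d = build_static_4d_mask(padding_mask)

# model.forward = torch.compile(model.forward, mode="reduce-overhead")
# out = model(input_ids, attention_mask=attention_mask_4d, past_key_values=static_cache)
```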
Motivation
/
Your contribution
/