
[generate + static cache + torch.compile] ability to pass statically shaped 4D attention_mask to the model forward #29165

Closed
fxmarty opened this issue Feb 21, 2024 · 5 comments · Fixed by #32227
Labels: Cache, Compilation (Issues related to torchdynamo and torchinductor), Generation

Comments

@fxmarty
Contributor

fxmarty commented Feb 21, 2024

Feature request

Currently, the attention_mask passed to the model forward is 2D and has a dynamic shape: it grows by one column at every decoding step.

This causes issues when using a compiled model with model.forward = torch.compile(model.forward, mode="reduce-overhead"), see pytorch/pytorch#120309 & #29114.
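
For reference, a minimal sketch of the kind of setup that produces the trace below, assuming a Llama checkpoint and the static cache; the checkpoint name and generation settings are illustrative and not taken from the issue:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any decoder-only model supporting the static cache would do.
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Compile the forward with CUDA graphs ("reduce-overhead"); any change in an input
# shape (here the 2D attention_mask growing by one column per step) forces the
# graph to be re-recorded.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer(["Hello", "Hi there"], return_tensors="pt", padding=True).to("cuda")
# cache_implementation="static" selects the static cache; the exact flag may vary by version.
out = model.generate(**inputs, max_new_tokens=4, cache_implementation="static")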

----- in forward 0
name=input_ids, shape=torch.Size([2, 7]), stride=(7, 1), dtype=torch.int64, device=cuda:0
name=position_ids, shape=torch.Size([2, 7]), stride=(7, 1), dtype=torch.int64, device=cuda:0
name=cache_position, shape=torch.Size([7]), stride=(1,), dtype=torch.int64, device=cuda:0
name=past_key_values, value=None
name=use_cache, value=True
name=attention_mask, shape=torch.Size([2, 7]), stride=(7, 1), dtype=torch.int64, device=cuda:0
forward call latency: 1784.737 ms      <---------------------------- EXTREMELY SLOW.
----- in forward 1
name=input_ids, shape=torch.Size([2, 1]), stride=(1, 1), dtype=torch.int64, device=cuda:0
name=position_ids, shape=torch.Size([2, 1]), stride=(1, 1), dtype=torch.int64, device=cuda:0
name=cache_position, shape=torch.Size([1]), stride=(1,), dtype=torch.int64, device=cuda:0
name=past_key_values, value=None
name=use_cache, value=True
name=attention_mask, shape=torch.Size([2, 8]), stride=(8, 1), dtype=torch.int64, device=cuda:0
forward call latency: 1851.579 ms      <---------------------------- EXTREMELY SLOW.
----- in forward 2
name=input_ids, shape=torch.Size([2, 1]), stride=(1, 1), dtype=torch.int64, device=cuda:0
name=position_ids, shape=torch.Size([2, 1]), stride=(1, 1), dtype=torch.int64, device=cuda:0
name=cache_position, shape=torch.Size([1]), stride=(1,), dtype=torch.int64, device=cuda:0
name=past_key_values, value=None
name=use_cache, value=True
name=attention_mask, shape=torch.Size([2, 9]), stride=(9, 1), dtype=torch.int64, device=cuda:0
forward call latency: 1421.504 ms      <---------------------------- EXTREMELY SLOW.
----- in forward 3
name=input_ids, shape=torch.Size([2, 1]), stride=(1, 1), dtype=torch.int64, device=cuda:0
name=position_ids, shape=torch.Size([2, 1]), stride=(1, 1), dtype=torch.int64, device=cuda:0
name=cache_position, shape=torch.Size([1]), stride=(1,), dtype=torch.int64, device=cuda:0
name=past_key_values, value=None
name=use_cache, value=True
name=attention_mask, shape=torch.Size([2, 10]), stride=(10, 1), dtype=torch.int64, device=cuda:0
forward call latency: 1740.283 ms      <---------------------------- EXTREMELY SLOW.

Instead, we may want to pass 4D masks of static shape directly to the model, avoiding CUDA graph re-recording at every step. This could bring the compile time down from more than 3 minutes to roughly 50 seconds.
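
As a hedged sketch of what this could look like: build an additive 4D mask with a fixed (batch, 1, query_len, max_cache_len) shape, so every decoding step sees the same tensor shape. The helper name and signature below are illustrative, not an existing transformers API, and the causal structure needed during prefill is omitted for brevity.

import torch

def build_static_4d_mask(attention_mask_2d, query_len, max_cache_len, dtype=torch.float16):
    # attention_mask_2d: (batch, key_len) with 1 for valid tokens and 0 for padding.
    batch_size = attention_mask_2d.shape[0]
    min_value = torch.finfo(dtype).min
    # Start fully masked, then unmask the positions marked valid in the 2D mask.
    mask = torch.full(
        (batch_size, 1, query_len, max_cache_len),
        min_value,
        dtype=dtype,
        device=attention_mask_2d.device,
    )
    key_len = attention_mask_2d.shape[-1]
    valid = attention_mask_2d[:, None, None, :].to(torch.bool)  # (batch, 1, 1, key_len)
    mask[..., :key_len] = mask[..., :key_len].masked_fill(valid, 0.0)
    return mask  # shape stays (batch, 1, query_len, max_cache_len) at every step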

Motivation

/

Your contribution

/

@amyeroberts
Collaborator

cc @gante

@fxmarty added the Compilation label Feb 28, 2024
@huggingface deleted a comment from github-actions bot Mar 25, 2024
@gante
Member

gante commented Mar 27, 2024

@fxmarty I had a quick look at this -- we still have models (like gpt2) that exclusively accept 2D masks. We would have to rework that before making generate prepare a 4D mask.

We may want to work with padded tensors instead? In our TF/PT XLA implementation we have static shapes everywhere.

@fxmarty
Contributor Author

fxmarty commented Mar 27, 2024

@gante thank you! What do you mean by work with padded tensors?

@gante
Member

gante commented Mar 27, 2024

@fxmarty If we want to generate with max_length=512, the attention mask is always kept with sequence length = 512. Data is moved around inside the tensor as needed.
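
For illustration, a minimal sketch of that bookkeeping (purely illustrative, not the actual generate internals):

import torch

# Keep the 2D attention_mask at a fixed length (max_length) for the whole
# generation, flipping entries in place as tokens are produced, so its shape
# never changes across decoding steps.
max_length = 512
batch_size, prompt_len = 2, 7

attention_mask = torch.zeros(batch_size, max_length, dtype=torch.long)
attention_mask[:, :prompt_len] = 1  # positions occupied by the prompt

for step in range(3):  # each new token marks one more position as attended-to
    attention_mask[:, prompt_len + step] = 1
# attention_mask.shape stays (2, 512) throughout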

@fxmarty
Contributor Author

fxmarty commented Mar 27, 2024

Oh yes, I think that's what I am suggesting here!
