[feat] Multihead dispatch benchmark #273
Conversation
[Two sets of benchmark result plots, each covering torch.float16 and torch.float32]
@@ -238,7 +238,6 @@ attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dr
 # "multi_head" will be responsible for the forward
 multi_head = (
     MultiHeadDispatch(
-        seq_len=SEQ,
nice catch, thanks
N_HEADS = [4]


def bench_multihead_dispatch(backward: bool):
super clean, nice
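For anyone reading along, here is a minimal sketch of the kind of comparison this benchmark makes. Only the MultiHeadDispatch / torch.nn.MultiheadAttention pairing is taken from the diff; the sizes, the ScaledDotProduct choice and the helper layout are illustrative assumptions, not the PR's actual code.

# Hedged sketch: compares xFormers' MultiHeadDispatch against PyTorch's
# torch.nn.MultiheadAttention on the same inputs. Sizes are illustrative.
import torch

from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct

BATCH, SEQ, EMB, HEADS = 8, 512, 384, 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# xFormers: a generic dispatcher wrapping a self-contained attention mechanism
xf_multi_head = MultiHeadDispatch(
    dim_model=EMB,
    num_heads=HEADS,
    attention=ScaledDotProduct(),
).to(device)

# PyTorch reference implementation
torch_multi_head = torch.nn.MultiheadAttention(
    embed_dim=EMB, num_heads=HEADS, batch_first=True
).to(device)

query = torch.randn(BATCH, SEQ, EMB, device=device)

def xformers_mha():
    return xf_multi_head(query=query, key=query, value=query)

def torch_mha():
    y, _ = torch_multi_head(query=query, key=query, value=query)
    return y

# The benchmark then times both callables (forward, plus backward when the
# `backward` flag is set) for dtypes such as torch.float16 and torch.float32.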
perfect, thanks @dianaml0! The results are very interesting. I for one thought that we could be a little slow in this particular spot, but that does not seem to be the case?
)


def torch_mha():
    y, _ = torch_multi_head(query=query, key=query, value=query)
nit: it could be nice to bench with and without self-attention. I know that self-attention is prevalent in vision, but for NLP that would not always be the case. Could be a boolean here, on top of "backward": self_attention or not?
Good idea, just added this :)
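For illustration, a hedged sketch of how such a flag could pick the benchmark inputs; make_inputs is a hypothetical helper, not code from the PR.

# Hypothetical helper illustrating the self_attention toggle discussed above.
import torch

def make_inputs(self_attention: bool, batch: int = 8, seq: int = 512, emb: int = 384):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    query = torch.randn(batch, seq, emb, device=device)
    if self_attention:
        # self-attention: query, key and value are the same tensor
        key = value = query
    else:
        # non-self-attention: key and value come from a different sequence
        key = torch.randn(batch, seq, emb, device=device)
        value = torch.randn(batch, seq, emb, device=device)
    return query, key, value

# bench_multihead_dispatch(backward, self_attention) can then feed these three
# tensors to both MultiHeadDispatch and torch.nn.MultiheadAttention.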
Codecov Report
@@           Coverage Diff           @@
##             main     #273   +/-   ##
=======================================
  Coverage   92.80%   92.80%
=======================================
  Files          61       61
  Lines        3363     3363
=======================================
  Hits         3121     3121
  Misses        242      242
Flags with carried forward coverage won't be shown. Continue to review the full report at Codecov.
Updated results with and without self-attention: [four sets of benchmark result plots, each covering torch.float16 and torch.float32]
What does this PR do?
Benchmarks xFormers' MultiHeadDispatch against PyTorch's version.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.