-
Notifications
You must be signed in to change notification settings - Fork 651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] Multihead dispatch benchmark #273
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from typing import Any, Dict | ||
|
||
import torch | ||
import torch.nn as nn | ||
import triton | ||
|
||
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print | ||
from xformers.components import MultiHeadDispatch | ||
from xformers.components.attention import ScaledDotProduct | ||
|
||
SHAPES = [ | ||
(8, 384, 128), | ||
(8, 784, 512), | ||
(4, 1024, 768), | ||
(4, 2048, 1024), | ||
(2, 2048, 2048), | ||
(2, 2048, 4096), | ||
(2, 4096, 4096), | ||
(1, 2048, 12288), | ||
] | ||
|
||
N_HEADS = [4] | ||
|
||
|
||
def bench_multihead_dispatch(backward: bool): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. super clean, nice |
||
device = torch.device("cuda") | ||
bw = "+bw" if backward else "" | ||
|
||
for dtype in [torch.float16, torch.float32]: | ||
results: Dict[str, Any] = {} | ||
|
||
for B, M, K in SHAPES: | ||
for heads in N_HEADS: | ||
xf_multi_head = MultiHeadDispatch( | ||
dim_model=K, | ||
residual_dropout=0.0, | ||
num_heads=heads, | ||
attention=ScaledDotProduct(), | ||
bias=True, | ||
).to(device=device, dtype=dtype) | ||
torch_multi_head = nn.MultiheadAttention( | ||
embed_dim=K, num_heads=heads, batch_first=True | ||
).to(device=device, dtype=dtype) | ||
|
||
query = torch.randn( | ||
(B, M, K), requires_grad=backward, device=device, dtype=dtype | ||
) | ||
|
||
def torch_mha(): | ||
y, _ = torch_multi_head(query=query, key=query, value=query) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: could be nice to bench with and without self attention, I know that self is prevalent in vision but for NLP that would not always be the case. Could be a boolean here, on top of "backwards", self_attention or not ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea, just added this :) |
||
if backward: | ||
torch.norm(y).backward() | ||
return y | ||
|
||
def xformers_mha(): | ||
y = xf_multi_head(query=query, key=query, value=query) | ||
if backward: | ||
torch.norm(y).backward() | ||
return y | ||
|
||
for testcase in [ | ||
TestCase(torch_mha, f"torch - fw{bw}"), | ||
TestCase(xformers_mha, f"xf - fw{bw}"), | ||
]: | ||
time = triton.testing.do_bench(testcase.function)[0] | ||
key = f"B={B}, M={M}, K={K}, N_HEADS={heads}" | ||
if key not in results: | ||
results[key] = {} | ||
|
||
results[key][testcase.name] = f"{time:.2f}" | ||
|
||
pretty_print( | ||
results, | ||
title=f"\n --- Type: {dtype} --- ", | ||
units="runtime in ms, lower is better", | ||
) | ||
pretty_plot( | ||
results, | ||
title=f"MHA-FW{bw}-{dtype}", | ||
units="runtime in ms, lower is better", | ||
dash_key="torch", | ||
) | ||
|
||
|
||
for bw in [False, True]: | ||
bench_multihead_dispatch(bw) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice catch, thanks