[feat] Multihead dispatch benchmark (#273)
* Multi-head dispatch benchmark
* Benchmark the non-self-attention case
* Include self-attention in the table key

authored-by: Diana Liskovich <dianaml@devfair0471.h2.fair>
dianaml0 authored Apr 13, 2022
1 parent f78ba0a commit 0c555e9
Showing 3 changed files with 106 additions and 2 deletions.
1 change: 0 additions & 1 deletion HOWTO.md
@@ -238,7 +238,6 @@ attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dr
 # "multi_head" will be responsible for the forward
 multi_head = (
     MultiHeadDispatch(
-        seq_len=SEQ,
         dim_model=EMB,
         residual_dropout=DROPOUT,
         num_heads=HEADS,
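
For context, a minimal sketch of how the HOWTO snippet reads after this change; EMB, DROPOUT, HEADS, and the attention instance are assumed from earlier in the guide, and the tail of the call is cut off by the hunk, so this is an approximation rather than the exact file content:

# "multi_head" will be responsible for the forward
multi_head = (
    MultiHeadDispatch(
        dim_model=EMB,            # model dimension, defined earlier in the HOWTO
        residual_dropout=DROPOUT, # dropout applied to the output projection
        num_heads=HEADS,
        attention=attention,      # the BlockSparseAttention instance built above (assumed)
    )
    .cuda()                       # device placement assumed; the hunk ends before this point
)
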
105 changes: 105 additions & 0 deletions xformers/benchmarks/benchmark_multi_head_dispatch.py
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict

import torch
import torch.nn as nn
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct

# Benchmarked problem sizes: (batch size, sequence length, model dimension)
SHAPES = [
    (8, 384, 128),
    (8, 784, 512),
    (4, 1024, 768),
    (4, 2048, 1024),
    (2, 2048, 2048),
    (2, 2048, 4096),
    (2, 4096, 4096),
    (1, 2048, 12288),
]

N_HEADS = [4]


def bench_multihead_dispatch(backward: bool, self_attention: bool):
    device = torch.device("cuda")
    bw = "+bw" if backward else ""
    sa = " (self_attn)" if self_attention else ""

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for B, M, K in SHAPES:
            for heads in N_HEADS:
                # xFormers multi-head dispatch wrapping a plain scaled dot-product attention
                xf_multi_head = MultiHeadDispatch(
                    dim_model=K,
                    residual_dropout=0.0,
                    num_heads=heads,
                    attention=ScaledDotProduct(),
                    bias=True,
                ).to(device=device, dtype=dtype)
                # PyTorch reference implementation
                torch_multi_head = nn.MultiheadAttention(
                    embed_dim=K, num_heads=heads, batch_first=True
                ).to(device=device, dtype=dtype)

                q = torch.randn(
                    (B, M, K), requires_grad=backward, device=device, dtype=dtype
                )

                if self_attention:
                    k = q
                    v = q
                else:
                    k = torch.randn(
                        (B, M, K), requires_grad=backward, device=device, dtype=dtype
                    )
                    v = torch.randn(
                        (B, M, K), requires_grad=backward, device=device, dtype=dtype
                    )

                def torch_mha():
                    y, _ = torch_multi_head(query=q, key=k, value=v)
                    if backward:
                        torch.norm(y).backward()
                    return y

                def xformers_mha():
                    y = xf_multi_head(query=q, key=k, value=v)
                    if backward:
                        torch.norm(y).backward()
                    return y

                # Time both implementations with triton's benchmarking helper
                for testcase in [
                    TestCase(torch_mha, f"torch - fw{bw}{sa}"),
                    TestCase(xformers_mha, f"xf - fw{bw}{sa}"),
                ]:
                    time = triton.testing.do_bench(testcase.function)[0]
                    key = f"B={B}, M={M}, K={K}, N_HEADS={heads}"
                    if key not in results:
                        results[key] = {}

                    results[key][testcase.name] = f"{time:.2f}"

        pretty_print(
            results,
            title=f"\n --- Type: {dtype} --- ",
            units="runtime in ms, lower is better",
        )
        pretty_plot(
            results,
            title=f"MHA-FW{bw}-{dtype}",
            units="runtime in ms, lower is better",
            dash_key="torch",
        )


# Run the full sweep: forward-only and forward+backward, cross- and self-attention
for bw in [False, True]:
    for self_attention in [False, True]:
        bench_multihead_dispatch(bw, self_attention)
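
As with the other benchmark scripts in this directory, the sweep above runs when the module is executed; a minimal usage sketch for a single configuration (assuming a CUDA device and a working triton install):

# Usage sketch: importing the module already runs the full sweep, since the loops
# above execute at import time; the call below then re-runs one configuration.
from xformers.benchmarks.benchmark_multi_head_dispatch import bench_multihead_dispatch

bench_multihead_dispatch(backward=True, self_attention=False)
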
2 changes: 1 addition & 1 deletion xformers/benchmarks/benchmark_revnet.py
@@ -75,7 +75,7 @@ def reversible_step():
     results,
     title=f"RevNet-FW{bw}-{dtype}",
     units="runtime in ms, lower is better",
-    dash_key="pytorch",
+    dash_key="torch",
 )


