From 235cd83ab115229ec20bdf1f520a5f82d4300707 Mon Sep 17 00:00:00 2001
From: Diana Liskovich
Date: Wed, 13 Apr 2022 11:42:43 -0700
Subject: [PATCH] multihead dispatch benchmark

---
 HOWTO.md                                             |  1 -
 .../benchmark_multi_head_dispatch.py                 | 92 +++++++++++++++++++
 xformers/benchmarks/benchmark_revnet.py              |  2 +-
 3 files changed, 93 insertions(+), 2 deletions(-)
 create mode 100644 xformers/benchmarks/benchmark_multi_head_dispatch.py

diff --git a/HOWTO.md b/HOWTO.md
index 55cddb1561..eb636aa419 100644
--- a/HOWTO.md
+++ b/HOWTO.md
@@ -238,7 +238,6 @@ attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dr
 # "multi_head" will be responsible for the forward
 multi_head = (
     MultiHeadDispatch(
-        seq_len=SEQ,
         dim_model=EMB,
         residual_dropout=DROPOUT,
         num_heads=HEADS,
diff --git a/xformers/benchmarks/benchmark_multi_head_dispatch.py b/xformers/benchmarks/benchmark_multi_head_dispatch.py
new file mode 100644
index 0000000000..3bcc1ebe43
--- /dev/null
+++ b/xformers/benchmarks/benchmark_multi_head_dispatch.py
@@ -0,0 +1,92 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Any, Dict
+
+import torch
+import torch.nn as nn
+import triton
+
+from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
+from xformers.components import MultiHeadDispatch
+from xformers.components.attention import ScaledDotProduct
+
+SHAPES = [
+    (8, 384, 128),
+    (8, 784, 512),
+    (4, 1024, 768),
+    (4, 2048, 1024),
+    (2, 2048, 2048),
+    (2, 2048, 4096),
+    (2, 4096, 4096),
+    (1, 2048, 12288),
+]
+
+N_HEADS = [4]
+
+
+def bench_multihead_dispatch(backward: bool):
+    device = torch.device("cuda")
+    bw = "+bw" if backward else ""
+
+    for dtype in [torch.float16, torch.float32]:
+        results: Dict[str, Any] = {}
+
+        for B, M, K in SHAPES:
+            for heads in N_HEADS:
+                xf_multi_head = MultiHeadDispatch(
+                    dim_model=K,
+                    residual_dropout=0.0,
+                    num_heads=heads,
+                    attention=ScaledDotProduct(),
+                    bias=True,
+                ).to(device=device, dtype=dtype)
+                torch_multi_head = nn.MultiheadAttention(
+                    embed_dim=K, num_heads=heads, batch_first=True
+                ).to(device=device, dtype=dtype)
+
+                query = torch.randn(
+                    (B, M, K), requires_grad=backward, device=device, dtype=dtype
+                )
+
+                def torch_mha():
+                    y, _ = torch_multi_head(query=query, key=query, value=query)
+                    if backward:
+                        torch.norm(y).backward()
+                    return y
+
+                def xformers_mha():
+                    y = xf_multi_head(query=query, key=query, value=query)
+                    if backward:
+                        torch.norm(y).backward()
+                    return y
+
+                for testcase in [
+                    TestCase(torch_mha, f"torch - fw{bw}"),
+                    TestCase(xformers_mha, f"xf - fw{bw}"),
+                ]:
+                    time = triton.testing.do_bench(testcase.function)[0]
+                    key = f"B={B}, M={M}, K={K}, N_HEADS={heads}"
+                    if key not in results:
+                        results[key] = {}
+
+                    results[key][testcase.name] = f"{time:.2f}"
+
+        pretty_print(
+            results,
+            title=f"\n --- Type: {dtype} --- ",
+            units="runtime in ms, lower is better",
+        )
+        pretty_plot(
+            results,
+            title=f"MHA-FW{bw}-{dtype}",
+            units="runtime in ms, lower is better",
+            dash_key="torch",
+        )
+
+
+for bw in [False, True]:
+    bench_multihead_dispatch(bw)
diff --git a/xformers/benchmarks/benchmark_revnet.py b/xformers/benchmarks/benchmark_revnet.py
index 512a861b28..8561481dd4 100644
--- a/xformers/benchmarks/benchmark_revnet.py
+++ b/xformers/benchmarks/benchmark_revnet.py
@@ -75,7 +75,7 @@ def reversible_step():
             results,
             title=f"RevNet-FW{bw}-{dtype}",
             units="runtime in ms, lower is better",
-            dash_key="pytorch",
+            dash_key="torch",
         )