From 235cd83ab115229ec20bdf1f520a5f82d4300707 Mon Sep 17 00:00:00 2001
From: Diana Liskovich
Date: Wed, 13 Apr 2022 11:42:43 -0700
Subject: [PATCH 1/3] multihead dispatch benchmark

---
 HOWTO.md                                 |  1 -
 .../benchmark_multi_head_dispatch.py     | 92 +++++++++++++++++++
 xformers/benchmarks/benchmark_revnet.py  |  2 +-
 3 files changed, 93 insertions(+), 2 deletions(-)
 create mode 100644 xformers/benchmarks/benchmark_multi_head_dispatch.py

diff --git a/HOWTO.md b/HOWTO.md
index 55cddb1561..eb636aa419 100644
--- a/HOWTO.md
+++ b/HOWTO.md
@@ -238,7 +238,6 @@ attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dr
 # "multi_head" will be responsible for the forward
 multi_head = (
     MultiHeadDispatch(
-        seq_len=SEQ,
         dim_model=EMB,
         residual_dropout=DROPOUT,
         num_heads=HEADS,
diff --git a/xformers/benchmarks/benchmark_multi_head_dispatch.py b/xformers/benchmarks/benchmark_multi_head_dispatch.py
new file mode 100644
index 0000000000..3bcc1ebe43
--- /dev/null
+++ b/xformers/benchmarks/benchmark_multi_head_dispatch.py
@@ -0,0 +1,92 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Any, Dict
+
+import torch
+import torch.nn as nn
+import triton
+
+from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
+from xformers.components import MultiHeadDispatch
+from xformers.components.attention import ScaledDotProduct
+
+SHAPES = [
+    (8, 384, 128),
+    (8, 784, 512),
+    (4, 1024, 768),
+    (4, 2048, 1024),
+    (2, 2048, 2048),
+    (2, 2048, 4096),
+    (2, 4096, 4096),
+    (1, 2048, 12288),
+]
+
+N_HEADS = [4]
+
+
+def bench_multihead_dispatch(backward: bool):
+    device = torch.device("cuda")
+    bw = "+bw" if backward else ""
+
+    for dtype in [torch.float16, torch.float32]:
+        results: Dict[str, Any] = {}
+
+        for B, M, K in SHAPES:
+            for heads in N_HEADS:
+                xf_multi_head = MultiHeadDispatch(
+                    dim_model=K,
+                    residual_dropout=0.0,
+                    num_heads=heads,
+                    attention=ScaledDotProduct(),
+                    bias=True,
+                ).to(device=device, dtype=dtype)
+                torch_multi_head = nn.MultiheadAttention(
+                    embed_dim=K, num_heads=heads, batch_first=True
+                ).to(device=device, dtype=dtype)
+
+                query = torch.randn(
+                    (B, M, K), requires_grad=backward, device=device, dtype=dtype
+                )
+
+                def torch_mha():
+                    y, _ = torch_multi_head(query=query, key=query, value=query)
+                    if backward:
+                        torch.norm(y).backward()
+                    return y
+
+                def xformers_mha():
+                    y = xf_multi_head(query=query, key=query, value=query)
+                    if backward:
+                        torch.norm(y).backward()
+                    return y
+
+                for testcase in [
+                    TestCase(torch_mha, f"torch - fw{bw}"),
+                    TestCase(xformers_mha, f"xf - fw{bw}"),
+                ]:
+                    time = triton.testing.do_bench(testcase.function)[0]
+                    key = f"B={B}, M={M}, K={K}, N_HEADS={heads}"
+                    if key not in results:
+                        results[key] = {}
+
+                    results[key][testcase.name] = f"{time:.2f}"
+
+        pretty_print(
+            results,
+            title=f"\n --- Type: {dtype} --- ",
+            units="runtime in ms, lower is better",
+        )
+        pretty_plot(
+            results,
+            title=f"MHA-FW{bw}-{dtype}",
+            units="runtime in ms, lower is better",
+            dash_key="torch",
+        )
+
+
+for bw in [False, True]:
+    bench_multihead_dispatch(bw)
diff --git a/xformers/benchmarks/benchmark_revnet.py b/xformers/benchmarks/benchmark_revnet.py
index 512a861b28..8561481dd4 100644
--- a/xformers/benchmarks/benchmark_revnet.py
+++ b/xformers/benchmarks/benchmark_revnet.py
@@ -75,7 +75,7 @@ def reversible_step():
             results,
             title=f"RevNet-FW{bw}-{dtype}",
             units="runtime in ms, lower is better",
-            dash_key="pytorch",
+            dash_key="torch",
         )
 
 

From 683d706601342b3cd0b58356bab0e7421346f163 Mon Sep 17 00:00:00 2001
From: Diana Liskovich
Date: Wed, 13 Apr 2022 12:53:16 -0700
Subject: [PATCH 2/3] benchmark non self attention case

---
 .../benchmark_multi_head_dispatch.py     | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/xformers/benchmarks/benchmark_multi_head_dispatch.py b/xformers/benchmarks/benchmark_multi_head_dispatch.py
index 3bcc1ebe43..7b332280e8 100644
--- a/xformers/benchmarks/benchmark_multi_head_dispatch.py
+++ b/xformers/benchmarks/benchmark_multi_head_dispatch.py
@@ -28,7 +28,7 @@
 N_HEADS = [4]
 
 
-def bench_multihead_dispatch(backward: bool):
+def bench_multihead_dispatch(backward: bool, self_attention: bool):
     device = torch.device("cuda")
     bw = "+bw" if backward else ""
 
@@ -48,18 +48,29 @@ def bench_multihead_dispatch(backward: bool, self_attention: bool):
                     embed_dim=K, num_heads=heads, batch_first=True
                 ).to(device=device, dtype=dtype)
 
-                query = torch.randn(
+                q = torch.randn(
                     (B, M, K), requires_grad=backward, device=device, dtype=dtype
                 )
 
+                if self_attention:
+                    k = q
+                    v = q
+                else:
+                    k = torch.randn(
+                        (B, M, K), requires_grad=backward, device=device, dtype=dtype
+                    )
+                    v = torch.randn(
+                        (B, M, K), requires_grad=backward, device=device, dtype=dtype
+                    )
+
                 def torch_mha():
-                    y, _ = torch_multi_head(query=query, key=query, value=query)
+                    y, _ = torch_multi_head(query=q, key=k, value=v)
                     if backward:
                         torch.norm(y).backward()
                     return y
 
                 def xformers_mha():
-                    y = xf_multi_head(query=query, key=query, value=query)
+                    y = xf_multi_head(query=q, key=k, value=v)
                     if backward:
                         torch.norm(y).backward()
                     return y
@@ -89,4 +100,5 @@ def xformers_mha():
 
 
 for bw in [False, True]:
-    bench_multihead_dispatch(bw)
+    for self_attention in [False, True]:
+        bench_multihead_dispatch(bw, self_attention)

From 9f9ba67b0c538bfc9896357f8b952d14e7d38f1a Mon Sep 17 00:00:00 2001
From: Diana Liskovich
Date: Wed, 13 Apr 2022 13:01:38 -0700
Subject: [PATCH 3/3] include in table key

---
 xformers/benchmarks/benchmark_multi_head_dispatch.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/xformers/benchmarks/benchmark_multi_head_dispatch.py b/xformers/benchmarks/benchmark_multi_head_dispatch.py
index 7b332280e8..e56f5fa6e4 100644
--- a/xformers/benchmarks/benchmark_multi_head_dispatch.py
+++ b/xformers/benchmarks/benchmark_multi_head_dispatch.py
@@ -31,6 +31,7 @@
 def bench_multihead_dispatch(backward: bool, self_attention: bool):
     device = torch.device("cuda")
     bw = "+bw" if backward else ""
+    sa = " (self_attn)" if self_attention else ""
 
     for dtype in [torch.float16, torch.float32]:
         results: Dict[str, Any] = {}
@@ -76,8 +77,8 @@ def xformers_mha():
                     return y
 
                 for testcase in [
-                    TestCase(torch_mha, f"torch - fw{bw}"),
-                    TestCase(xformers_mha, f"xf - fw{bw}"),
+                    TestCase(torch_mha, f"torch - fw{bw}{sa}"),
+                    TestCase(xformers_mha, f"xf - fw{bw}{sa}"),
                 ]:
                     time = triton.testing.do_bench(testcase.function)[0]
                     key = f"B={B}, M={M}, K={K}, N_HEADS={heads}"