# benchmark_qkvpacked_func.py
# (from a fork of zhuzilin/ring-flash-attention)
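# Minimal usage note (assumed launch command; any torch.distributed launcher that
# starts one process per GPU works, since the script calls
# dist.init_process_group("nccl")):
#   torchrun --nproc_per_node=<num_gpus> benchmark_qkvpacked_func.py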
from flash_attn import flash_attn_qkvpacked_func
import torch
import torch.distributed as dist
from ring_flash_attn import (
    ring_flash_attn_qkvpacked_func,
    ring_flash_attn_qkvpacked_func_v2,
    zigzag_ring_flash_attn_qkvpacked_func,
    stripe_flash_attn_qkvpacked_func,
)
import torch.cuda


def benchmark_forward(f, num_benchmark_iter=1000, log=True):
    # Start from an empty CUDA caching-allocator state so runs are comparable.
    torch.cuda.empty_cache()

    dtype = torch.bfloat16
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    batch_size = 1
    seqlen = 1024 * 8
    nheads = 5
    d = 128
    dropout_p = 0
    causal = True
    deterministic = False

    # The zigzag/stripe schedules split the sequence into 2 * world_size chunks,
    # and flash-attn requires the head dimension to be a multiple of 8.
    assert seqlen % (2 * world_size) == 0
    assert d % 8 == 0

    # QKV-packed input per rank: (batch, seqlen, 3, nheads, headdim).
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )

    # Time num_benchmark_iter forward passes with CUDA events.
    begin = torch.cuda.Event(enable_timing=True)
    begin.record()
    with torch.no_grad():
        for _ in range(num_benchmark_iter):
            _ = f(
                qkv,
                dropout_p=dropout_p,
                causal=causal,
                window_size=(-1, -1),
                alibi_slopes=None,
                deterministic=deterministic,
                return_attn_probs=False,
            )
    end = torch.cuda.Event(enable_timing=True)
    end.record()
    torch.cuda.synchronize(device=device)
    # elapsed_time() returns milliseconds; convert to seconds.
    time = begin.elapsed_time(end) / 1000.0

    if rank == 0 and log:
        print(f"{f.__name__} {num_benchmark_iter / time} iter/s, {time} sec")


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()

    # Warm up every kernel once so one-time initialization (CUDA context,
    # NCCL communicators, kernel caches) does not skew the timed runs below.
    if rank == 0:
        print("warming up...")
    benchmark_forward(flash_attn_qkvpacked_func, log=False)
    benchmark_forward(ring_flash_attn_qkvpacked_func, log=False)
    benchmark_forward(ring_flash_attn_qkvpacked_func_v2, log=False)
    benchmark_forward(stripe_flash_attn_qkvpacked_func, log=False)
    benchmark_forward(zigzag_ring_flash_attn_qkvpacked_func, log=False)

    # Timed runs: rank 0 prints iterations/second and total time per kernel.
    if rank == 0:
        print("benchmark:")
    benchmark_forward(flash_attn_qkvpacked_func)
    benchmark_forward(ring_flash_attn_qkvpacked_func)
    benchmark_forward(ring_flash_attn_qkvpacked_func_v2)
    benchmark_forward(stripe_flash_attn_qkvpacked_func)
    benchmark_forward(zigzag_ring_flash_attn_qkvpacked_func)