-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtriton_matmul.py
124 lines (115 loc) · 5.06 KB
/
triton_matmul.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import triton.language as tl
import triton
@triton.heuristics({
'EVEN_K': lambda *args, **meta: False #lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0,
})
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'],
)
@triton.jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
LOCKS,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z*BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K*SPLIT_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(tl.float16)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
LOCKS = LOCKS + tl.program_id(0)
COUNT = LOCKS + tl.num_programs(0)
while tl.atomic_cas(LOCKS, 0, 1) == 1:
pass
count = tl.load(COUNT)
if count == 0:
tl.store(C, acc, mask=mask)
else:
curr = tl.load(C, mask=mask, other=0.)
tl.store(C, acc + curr, mask=mask)
tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
tl.atomic_xchg(LOCKS, 0)
class _matmul(torch.autograd.Function):
kernel = _kernel
_locks = dict()
@staticmethod
def _call(a, b):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
# allocate locks for split-k
if a.device not in _matmul._locks:
_matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
locks = _matmul._locks[device]
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
locks, GROUP_M=8)
return c
@staticmethod
def forward(ctx, a, b):
return _matmul._call(a, b)
matmul_split_k = _matmul.apply