- small cleanup, dumping the 4 tiles dropout option
- adding the new plots
blefaudeux committed Nov 13, 2022
1 parent 6a294d6 commit c8f656d
Showing 68 changed files with 91 additions and 61 deletions.
2 changes: 2 additions & 0 deletions BENCHMARKS.md
@@ -37,6 +37,8 @@ Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py

## Triton layers

Please note that, as of November 2022, these layers are not optimized for typical production GPUs (they have not been developed for some time and were mostly tested on a laptop GPU), and that better performance is probably achievable with minor changes, as other libraries have shown since xformers was released.

### Fused softmax

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results were measured on a laptop NVIDIA 3080, with Triton 2.0 and PyTorch 1.12.
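As a quick sanity check, the fused softmax can also be exercised directly from Python. A minimal sketch (assuming a CUDA device, the optional Triton dependency installed, and that `xformers.triton.softmax` applies softmax over the last dimension; the tolerances are illustrative, not taken from the benchmark script):

```python
# Minimal sketch: compare the fused Triton softmax against the PyTorch baseline.
import torch
from xformers.triton import softmax as triton_softmax  # optional Triton-backed kernel

x = torch.randn(16, 1024, 1024, device="cuda", dtype=torch.float16)

y_ref = torch.softmax(x, dim=-1)   # eager PyTorch reference
y_fused = triton_softmax(x)        # fused kernel, softmax over the last dimension

# fp16 comparison with loose, illustrative tolerances
torch.testing.assert_close(y_ref, y_fused, rtol=1e-2, atol=1e-2)
```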
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_smelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_smelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_squared_relu.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float32.png
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -27,4 +27,4 @@ hydra-core >= 1.1
fairscale >= 0.4.5

# Dependency for fused layers, optional
triton==2.0.0.dev20221014
triton==2.0.0.dev20221105
4 changes: 3 additions & 1 deletion tests/test_triton_dropout.py
@@ -17,6 +17,8 @@

if _triton_available:
try:
import triton

from xformers.triton import dropout as triton_dropout
from xformers.triton.dropout import FusedDropoutBias
from xformers.triton.utils import gpu_capabilities_older_than_70
@@ -130,7 +132,7 @@ def test_dropout(shape, amp, bias, p):
torch.cuda.manual_seed(0)
y_2 = triton_dropout(x, p=0.5)

assert torch.allclose(y_1, y_2)
triton.testing.assert_almost_equal(y_1, y_2)


@pytest.mark.skipif(not _gpu_available, reason="GPU is not available")
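The exact `torch.allclose` check above is relaxed to `triton.testing.assert_almost_equal`: with the same seed the keep mask matches, but the fp16 `1/(1-p)` rescale can leave small numerical differences. A sketch of an equivalent tolerance-aware check using only PyTorch utilities (the tolerances are assumptions, not values from the test suite):

```python
# Hedged alternative to triton.testing.assert_almost_equal, using plain PyTorch.
import torch

def assert_dropout_outputs_close(y_1: torch.Tensor, y_2: torch.Tensor) -> None:
    # Same seed -> same keep mask; only small fp16 rounding differences should remain
    # after the 1/(1-p) rescale, hence the loose (illustrative) tolerances.
    torch.testing.assert_close(y_1, y_2, rtol=1e-2, atol=1e-2)
```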
3 changes: 2 additions & 1 deletion tests/test_triton_fused_linear.py
@@ -65,7 +65,8 @@ def test_fused_matmul(shape, dtype):
)

# Now check that adding an activation to the mix still produces valid results
# NOTE: SquaredReLU fails, some outlier representation issue
# NOTE: SquaredReLU fails with what looks like an outlier representation issue, but the eyeballed results look reasonable;
# this could be due to a different accumulation out of the box (tf32 for instance)
for activation in filter(
lambda x: x not in (Activation.SquaredReLU, Activation.StarReLU), Activation
):
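One possible reading of the SquaredReLU note above, shown as a standalone sketch (an illustration of the suspected fp16 outlier sensitivity, not code from the test): after squaring, outliers become large, so fp16 rounding shows up as sizeable absolute differences versus an fp32 reference, which can trip strict comparisons.

```python
# Illustrative only: fp16 squared ReLU versus an fp32 reference.
import torch

x = torch.randn(1024, 1024) * 4.0              # include some large-ish outliers
ref = torch.relu(x).pow(2)                      # fp32 reference
half = torch.relu(x.half()).pow(2).float()      # same computation carried out in fp16

print((ref - half).abs().max().item())          # grows with the squared outlier magnitude
```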
6 changes: 3 additions & 3 deletions xformers/triton/dropout.py
@@ -18,7 +18,7 @@
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw

BLOCK_M = 32
BLOCK_N = 128
BLOCK_N = 64 # NOTE: This should ideally be GPU-dependent; it has a big impact on perf


# Helper to handle the SPMD launch grid and error cases
@@ -36,7 +36,7 @@ def forward(ctx, x, p, bias, activation, trainable_bias):

def grid(meta):
return (
triton.cdiv(M, meta["BLOCK_M"]), # 4 x
triton.cdiv(M, meta["BLOCK_M"]),
triton.cdiv(N, meta["BLOCK_N"]),
)

@@ -101,7 +101,7 @@ def backward(
# - over N we compromise between trying to use as much memory parallelism as possible
# (fill in the warps: there are 32 threads per warp, and 4 warps by default), and not going too
# big because of register spilling
N_BLOCKS_M = triton.cdiv(M, BLOCK_M) # 4x
N_BLOCKS_M = triton.cdiv(M, BLOCK_M)

if ctx.trainable_bias:
grad_bias = torch.empty(
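For context, a standalone sketch of the 2D SPMD launch grid defined above: each program instance covers one (BLOCK_M, BLOCK_N) tile of the flattened (M, N) input. `triton.cdiv` is a ceiling division; the block sizes below are the defaults from this file.

```python
# Minimal sketch of the launch-grid computation, with plain Python standing in for triton.cdiv.
import math

def launch_grid(M: int, N: int, BLOCK_M: int = 32, BLOCK_N: int = 64):
    # triton.cdiv(a, b) == ceil(a / b): one program per (BLOCK_M, BLOCK_N) tile
    return (math.ceil(M / BLOCK_M), math.ceil(N / BLOCK_N))

print(launch_grid(4096, 1024))  # (128, 16) programs for a 4096 x 1024 activation
```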
38 changes: 10 additions & 28 deletions xformers/triton/k_dropout.py
@@ -31,9 +31,6 @@
triton.Config({}, num_warps=16),
]

MAX_INT32 = 2147483647
MAX_UINT32 = 4294967295


# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@@ -66,7 +63,7 @@ def k_dropout_fw(
# fmt: on

row_id = tl.program_id(axis=0)
rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) # 4x
rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)

col_id = tl.program_id(axis=1)
cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)
@@ -106,20 +103,12 @@ def k_dropout_fw(
# get the random keep mask
rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
seed_int = tl.load(SEEDS + col_id)
r = tl.rand(seed_int, rand_offsets)
keep_mask = r > p

if 1:
r = tl.rand(seed_int, rand_offsets)
keep_mask = r > p

# prune and normalize in one go
keep = tl.reshape(keep_mask, x.shape)
output = tl.where(keep, (x * p_scale).to(x.dtype), 0.)
else:
r0, r1, r2, r3 = tl.randint4x(seed_int, rand_offsets)
r = tl.cat(tl.cat(r0, r1), tl.cat(r2, r3))
r = r.to(tl.uint32, bitcast=True)
r = tl.reshape(r, x.shape)
output = tl.where(r > p * MAX_UINT32, x * p_scale, 0.)
# prune and normalize in one go
keep = tl.reshape(keep_mask, x.shape)
output = tl.where(keep, (x * p_scale).to(x.dtype), 0.)

tl.store(y_ptrs, output, mask=block_mask) # output

@@ -158,7 +147,7 @@ def k_dropout_bw(
# fmt: on

row_id = tl.program_id(axis=0)
rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) # 4x
rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)

col_id = tl.program_id(axis=1)
cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)
@@ -206,16 +195,9 @@ def k_dropout_bw(
# from the same seeds, so the same drop mask is applied here
rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
seed_int = tl.load(SEEDS + col_id)
if 1:
r = tl.rand(seed_int, rand_offsets)
r = tl.reshape(r, grad_out.shape)
output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.)
else:
r0, r1, r2, r3 = tl.randint4x(seed_int, rand_offsets)
r = tl.cat(tl.cat(r0, r1), tl.cat(r2, r3))
r = r.to(tl.uint32, bitcast=True)
r = tl.reshape(r, inputs.shape)
output = tl.where(r > p * MAX_UINT32, grad_out * p_scale, 0.)
r = tl.rand(seed_int, rand_offsets)
r = tl.reshape(r, grad_out.shape)
output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.)

# write-back
tl.store(grad_in_ptrs, output, mask=block_mask)
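The diff above drops the experimental `tl.randint4x` / uint32 branch (presumably the "4 tiles" dropout option mentioned in the commit message) and keeps the single `tl.rand` path. A hedged, plain-PyTorch sketch of the retained semantics, ignoring the kernel's per-column seeding and tiling (`seeded_dropout` is a name introduced here for illustration):

```python
# Illustrative sketch of the kept tl.rand() path: seeded keep mask, then prune and
# rescale by 1/(1-p) in one go. Not the kernel itself, just the math it implements.
import torch

def seeded_dropout(x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
    assert 0.0 <= p < 1.0
    g = torch.Generator(device=x.device).manual_seed(seed)
    r = torch.rand(x.shape, generator=g, device=x.device)  # uniform in [0, 1)
    keep = r > p                                            # same keep-mask rule as the kernel
    p_scale = 1.0 / (1.0 - p)
    return torch.where(keep, (x * p_scale).to(x.dtype), torch.zeros_like(x))
```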
18 changes: 8 additions & 10 deletions xformers/triton/k_fused_matmul_bw.py
@@ -18,18 +18,16 @@
squared_relu_grad,
star_relu_grad,
)
from xformers.triton.sum_strided import sum_2d_dim_0


# fmt: off
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 32}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_N": 64}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_N": 64}, num_stages=4, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=2),
triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=4),
],
key=["N"],
)
@@ -155,9 +153,9 @@ def fused_matmul_backward(
# just before the activation
grad_out_ = grad_act

# The following ops can also be handled by triton
grad_in = grad_out_ @ weight
# The following ops can also be handled by pytorch
grad_in = triton.ops.matmul(grad_out_, weight)
grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None
grad_bias = sum_2d_dim_0(grad_out_) if trainable_bias else None
grad_bias = torch.sum(grad_out_, dim=0) if trainable_bias else None

return grad_in.reshape_as(inputs), grad_weight, grad_bias
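For reference, the same backward math written with plain PyTorch ops, mirroring the lines above (a sketch assuming the usual `y = x @ weight.T + bias` layout, with `grad_out_` already including the activation gradient):

```python
# Reference sketch of the gradients computed by fused_matmul_backward.
import torch

def fused_linear_backward_reference(grad_out_, inputs_, weight,
                                    trainable_weight=True, trainable_bias=True):
    grad_in = grad_out_ @ weight                                                     # dL/dx
    grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None  # dL/dW
    grad_bias = torch.sum(grad_out_, dim=0) if trainable_bias else None              # dL/db
    return grad_in, grad_weight, grad_bias
```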
79 changes: 62 additions & 17 deletions xformers/triton/k_fused_matmul_fw.py
@@ -21,24 +21,64 @@
# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications


def get_configs(block_k):
return [
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": block_k},
num_stages=4,
num_warps=2,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": block_k},
num_stages=4,
num_warps=2,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": block_k},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": block_k},
num_stages=3,
num_warps=8,
),
]


# fmt: off
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 16, "BLOCK_N": 16}, num_stages=5, num_warps=1),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_stages=5, num_warps=1),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=5, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 32}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64}, num_stages=3, num_warps=8),
],
configs=[c for block_k in [32, 64] for c in get_configs(block_k)],
key=["M", "N", "K"],
)
@triton.heuristics({
'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0,
})
@triton.jit
def kernel_fma(
# Pointers to matrices
@@ -53,6 +93,7 @@ def kernel_fma(
# Meta-parameters
BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
EVEN_N: tl.constexpr,
BIAS: tl.constexpr,
SAVE_ACT_INPUTS: tl.constexpr,
ACTIVATION: tl.constexpr,
@@ -110,7 +151,10 @@ def kernel_fma(
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

if BIAS:
bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
if EVEN_N:
bias = tl.load(bias + rn).to(tl.float32)
else:
bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
acc += bias[None, :]

# block level matrix multiplication.
@@ -125,6 +169,9 @@ def kernel_fma(

acc += tl.dot(a, w)

rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

# optional: save the activation inputs
if SAVE_ACT_INPUTS:
act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :]
@@ -184,7 +231,6 @@ def fused_matmul(

# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa
BLOCK_K = 32 if K < 1024 else 64

# fmt: off
kernel_fma[grid](
@@ -196,7 +242,6 @@ def fused_matmul(
ACTIVATION=activation, # optional fused activation
BIAS=bias is not None, # optional fused bias
GROUP_M=8, # speed optimization: group the programs
BLOCK_K=BLOCK_K,
SAVE_ACT_INPUTS=save_act_inputs,
is_fp16=x_.dtype == torch.float16
)
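The autotuning change above replaces the hard-coded `BLOCK_K = 32 if K < 1024 else 64` launch-time choice with a config factory, so `BLOCK_K` is swept by the autotuner alongside the tile shapes. A self-contained sketch of the pattern, with plain dicts standing in for `triton.Config`:

```python
# Illustrative sketch: sweep the same tile/stage/warp grid for each candidate BLOCK_K.
def get_configs_sketch(block_k):
    tiles = [
        (64, 32, 4, 2), (32, 64, 4, 2),
        (128, 64, 3, 4), (64, 128, 3, 4), (128, 128, 3, 4),
        (64, 256, 3, 4), (256, 64, 3, 4),
        (128, 256, 3, 8), (256, 128, 3, 8),
    ]
    return [
        {"BLOCK_M": m, "BLOCK_N": n, "BLOCK_K": block_k, "num_stages": s, "num_warps": w}
        for m, n, s, w in tiles
    ]

configs = [c for block_k in (32, 64) for c in get_configs_sketch(block_k)]
print(len(configs))  # 9 tile shapes x 2 BLOCK_K candidates = 18 autotune configs
```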
