[benchmark] append matmul benchmark scripts (apache#7)
* remove redundant transparency

* dsl benchmark scripts
LeiWang1999 authored Feb 28, 2024
1 parent 0ecada8 commit f6cb733
Showing 3 changed files with 106 additions and 378 deletions.
296 changes: 0 additions & 296 deletions Transparency.md

This file was deleted.

77 changes: 28 additions & 49 deletions benchmark/dsl/matmul.py
@@ -8,6 +8,7 @@
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from bitblas.gpu import Matmul
from bitblas.base.utils import apply_and_build
from bitblas.ops.impl.matmul_impl import matmul_nt_propagate_a_propagate_b
import time


@@ -93,54 +94,8 @@ def main(a: T.handle, b: T.handle, c: T.handle):

return MyModule


def matmul_nt_propagate_a_b(M, N, K, in_dtype="float16", out_dtype="float16"):
wm, wn, wk = 16, 16, 16
if in_dtype == "int8":
wm, wn, wk = 16, 16, 32

@tvm.script.ir_module
class MyModule:
@T.prim_func
def main(a: T.handle, b: T.handle, c: T.handle):
T.func_attr(
{
"global_symbol": "main",
"tir.noalias": True,
"smooth_a": True,
"smooth_b": True,
}
)
A = T.match_buffer(a, [M // wm, K // wk, wm, wk], dtype=in_dtype)
B = T.match_buffer(b, [N // wn, K // wk, wn, wk], dtype=in_dtype)
C = T.match_buffer(c, [M, N], dtype=out_dtype)
A_reindex = T.alloc_buffer([M, K], dtype=in_dtype)
B_reindex = T.alloc_buffer([N, K], dtype=in_dtype)

for i, k in T.grid(M, K):
with T.block("A_reindex"):
vj, vk = T.axis.remap("SS", [i, k])
A_reindex[vj, vk] = A[vj // wm, vk // wk, vj % wm, vk % wk]

for j, k in T.grid(N, K):
with T.block("B_reindex"):
vj, vk = T.axis.remap("SS", [j, k])
B_reindex[vj, vk] = B[vj // wn, vk // wk, vj % wn, vk % wk]

for i, j, k in T.grid(M, N, K):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = tvm.tir.const(0, out_dtype)
C[vi, vj] = C[vi, vj] + A_reindex[vi, vk].astype(
out_dtype
) * B_reindex[vj, vk].astype(out_dtype)

return MyModule


# fmt:off
benchmark_sets = [
typical_test_shapes = [
# (prim_func, input_args, default_dlight_schedule),
(matmul_nt, (1024, 1024, 1024, "float16", "float16"), Matmul),
(matmul_nt, (16, 8192, 8192, "float16", "float16"), Matmul),
@@ -152,11 +107,35 @@ def main(a: T.handle, b: T.handle, c: T.handle):
(matmul_nn, (16384, 16384, 16384, "float16", "float16"), Matmul),
(matmul_nt, (1024, 1024, 1024, "float32", "float32"), Matmul),
(matmul_nt_propagate_b_f16_f16_mma, (16384, 16384, 16384), Matmul),
(matmul_nt_propagate_a_b, (16384, 16384, 16384, "int8", "int32"), Matmul),
(matmul_nt_propagate_a_b, (16384, 16384, 16384, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "int8", "int32", "int32"), Matmul),
(matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16"), Matmul),
]
# fmt:on

llm_shapes = [
# square test
(matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16"), Matmul),
# BLOOM-176B
(matmul_nt_propagate_a_propagate_b, (8192, 43008, 14336, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 14336, 14336, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 57344, 14336, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 14336, 57344, "float16", "float16"), Matmul),
# # OPT-65B
(matmul_nt_propagate_a_propagate_b, (8192, 9216, 9216, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 36864, 9216, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 9216, 36864, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 22016, 8192, "float16", "float16"), Matmul),
# # LLAMA-70B/65B
(matmul_nt_propagate_a_propagate_b, (8192, 8192, 22016, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 8192, 8192, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 28672, 8192, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 8192, 28672, "float16", "float16"), Matmul),
]

benchmark_sets = []
benchmark_sets.extend(llm_shapes)
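
As this hunk stands, only the LLM shapes are actually benchmarked: typical_test_shapes is defined above but never added to benchmark_sets. If the earlier square/generic shapes are wanted as well, a one-line opt-in along these lines (not part of this commit) would presumably do it:

# Not part of this commit: also benchmark the generic test shapes defined above.
benchmark_sets.extend(typical_test_shapes)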


benchmark_results = {}
for get_prim_func, input_args, d_schedule in benchmark_sets:
ir_module = get_prim_func(*input_args)
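
The hunk is truncated here; the rest of the loop is collapsed in the original view. For orientation only, the following is a hedged sketch of how the loop presumably continues, based on the imports visible in the first hunk (get_tensorized_func_and_tags, Matmul, apply_and_build, time) and on the usual BitBLAS fast-tune flow. The roller arch/policy imports, the target string, the config count, and the result-dict keys are assumptions, not the commit's actual code, and the default-schedule (Matmul) baseline comparison is elided from the sketch.

# Hedged sketch only -- not the file's actual continuation.
# Assumption: the roller arch/policy classes below exist with these signatures.
from bitblas.base.roller.arch import CUDA
from bitblas.base.roller.policy import DefaultPolicy, TensorCorePolicy

for get_prim_func, input_args, d_schedule in benchmark_sets:
    ir_module = get_prim_func(*input_args)      # restates the two visible lines above
    func = ir_module["main"]
    target = tvm.target.Target("cuda")          # assumed; the script may pin a concrete GPU
    arch = CUDA(target)

    # Prefer a tensor-core policy when the workload can be tensorized.
    policy = DefaultPolicy(func=func, arch=arch)
    try:
        tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
    except Exception:
        tensorized_func, tags = func, None
    if tags:
        policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)

    # Fast-tune: emit candidate schedule configs, build them in parallel,
    # and keep the fastest variant.
    configs = policy.emit_config(20)            # assumed top-k
    tune_start = time.time()
    cpresults, best = apply_and_build(func, configs, arch, parallel_build=True)
    benchmark_results[f"{get_prim_func.__name__}-{input_args}"] = {
        "fast_tune_time_s": time.time() - tune_start,
        "best_latency": best.latency,           # assumed attribute; units per the helper
    }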
