
Commit

[Tensorize][TOPI] Add AMX Tensorizing for int8 batch matmul (apache#13745)

* amx int8 tensorized x86 bmm

* remove the unused amx schedule

* fix lint

* fix lint

* remove unused import

* fix Instr. assert in testcase.
Qianshui-Jiang authored and fzi-peccia committed Mar 27, 2023
1 parent a13648f commit 9edabfe
Showing 4 changed files with 104 additions and 35 deletions.
10 changes: 3 additions & 7 deletions python/tvm/relay/op/strategy/x86.py
@@ -23,9 +23,7 @@
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.meta_schedule import is_meta_schedule_enabled
from tvm.relay.ty import is_dynamic
from tvm.target import Target
from tvm.te import SpecializedCondition
from tvm.topi.x86.utils import target_has_vnni

from .. import op as _op
from .generic import *
@@ -618,24 +616,22 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
"""batch_matmul x86 strategy"""
strategy = _op.OpStrategy()
mcpu = Target.current().mcpu

need_auto_scheduler_layout = is_auto_scheduler_enabled()
need_meta_schedule_layout = is_meta_schedule_enabled()

if (
not attrs.transpose_a
and attrs.transpose_b
and target_has_vnni(mcpu)
and inputs[0].dtype == "uint8"
and inputs[1].dtype == "int8"
and inputs[1].shape[-2] % 16 == 0
and inputs[1].shape[-1] % 4 == 0
):
strategy.add_implementation(
wrap_compute_batch_matmul(topi.x86.batch_matmul_vnni_compute, need_out_dtype=True),
wrap_topi_schedule(topi.x86.schedule_batch_matmul_vnni),
name="batch_matmul_vnni.x86",
wrap_compute_batch_matmul(topi.x86.batch_matmul_int8_compute, need_out_dtype=True),
wrap_topi_schedule(topi.x86.schedule_batch_matmul_int8),
name="batch_matmul_int8.x86",
plevel=10,
)
elif is_dynamic(out_type) or need_auto_scheduler_layout or need_meta_schedule_layout:
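For context, a minimal Relay sketch of an input the condition above accepts (uint8 × int8 operands in transpose_b layout, with N divisible by 16 and K by 4). The shapes here are illustrative assumptions; on an AMX- or VNNI-capable CPU this is expected to be lowered through batch_matmul_int8.x86:

import tvm
from tvm import relay

# Hedged sketch: shapes chosen so that N % 16 == 0 and K % 4 == 0, as the strategy requires.
b, m, n, k = 4, 32, 32, 128
x = relay.var("x", shape=(b, m, k), dtype="uint8")
y = relay.var("y", shape=(b, n, k), dtype="int8")  # transpose_b layout (B, N, K)
out = relay.nn.batch_matmul(x, y, out_dtype="int32")
mod = tvm.IRModule.from_expr(out)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=sapphirerapids")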
53 changes: 42 additions & 11 deletions python/tvm/topi/x86/batch_matmul.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals,unused-variable
# pylint: disable=unused-argument
"""x86 batch_matmul operators"""
import tvm
from tvm import autotvm, te
@@ -24,18 +25,24 @@
from .. import generic, nn
from ..transform import layout_transform
from ..utils import get_const_tuple, get_max_power2_factor, traverse_inline
from .dense import dense_vnni_schedule
from .dense import dense_vnni_schedule, dense_amx_int8_schedule
from .injective import schedule_injective_from_existing
from .utils import target_has_vnni, target_has_amx


@autotvm.register_topi_compute("batch_matmul_vnni.x86")
def batch_matmul_vnni_compute(cfg, x, y, *_):
def batch_matmul_int8_compute(cfg, x, y, *_):
"""Compute for uint8 x int8 -> int32 batch_matmul"""
batch, m, k = x.shape
packed_y_layout = "BNK16n4k"
packed_y = layout_transform(y, "BNK", packed_y_layout)
_, n_o, _, n_i, _ = packed_y.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_vnni(mcpu):
attrs_info = {"schedule_rule": "batch_matmul_vnni"}
else:
attrs_info = None

z = te.compute(
(batch, m, n_o * n_i),
@@ -46,14 +53,10 @@ def batch_matmul_vnni_compute(cfg, x, y, *_):
),
axis=ak,
),
tag="batch_matmul_vnni",
attrs={"schedule_rule": "batch_matmul_vnni"},
tag="batch_matmul_int8",
attrs=attrs_info,
)

_, a_y, _ = z.op.axis
cfg.define_split("tile_y", a_y, num_outputs=2)
cfg.define_knob("layout_trans_compute_root", [0, 1])

return z
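The "BNK16n4k" packing used above tiles the weight's N axis by 16 and its K axis by 4, which is the granularity the int8 dot-product instructions consume. A rough numpy sketch of what that packing is assumed to look like (an illustration inferred from the layout string, not code from this commit):

import numpy as np

def pack_bnk16n4k(y):
    # y: (B, N, K) int8; the strategy requires N % 16 == 0 and K % 4 == 0.
    B, N, K = y.shape
    # "BNK" -> "BNK16n4k": split N into (N//16, 16) and K into (K//4, 4),
    # keeping the 16-wide n tile and 4-wide k tile as the innermost axes.
    return y.reshape(B, N // 16, 16, K // 4, 4).transpose(0, 1, 3, 2, 4)

y = np.random.randint(-128, 127, size=(2, 32, 64), dtype=np.int8)
packed = pack_bnk16n4k(y)
print(packed.shape)  # (2, 2, 16, 16, 4) -> (B, N//16, K//4, 16n, 4k)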


@@ -67,6 +70,7 @@ def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
# Parallelize over batch
fused = s[O].fuse(O.op.axis[0], fused_inner)
s[O].parallel(fused)
cfg.define_knob("layout_trans_compute_root", [0, 1])

if cfg["layout_trans_compute_root"].val:
s[layout_trans].compute_root()
@@ -80,6 +84,29 @@ def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
return s


def batch_matmul_amx_schedule(cfg, s, C, O, layout_trans):
"""Schedule batch_matmul compute using AMX tdpbusd instruction"""
# C: The output of batched GEMM
# O: The output of the fused op

# Schedule the GEMM part
s, fused_inner = dense_amx_int8_schedule(cfg, s, C, O, do_parallel=False)
# Parallelize over outer loop
fused = s[O].fuse(O.op.axis[0], fused_inner)
s[O].parallel(fused)
cfg.define_knob("layout_trans_compute_root", [0, 1])

if cfg["layout_trans_compute_root"].val:
s[layout_trans].compute_root()
schedule_injective_from_existing(s, layout_trans)
else:
_, _, _, ni, ki = s[layout_trans].op.axis
s[layout_trans].vectorize(ki)
s[layout_trans].unroll(ni)

return s


@autotvm.register_topi_compute("batch_matmul.x86")
def batch_matmul(
cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
@@ -202,14 +229,18 @@ def _callback(op):


@autotvm.register_topi_schedule("batch_matmul_vnni.x86")
def schedule_batch_matmul_vnni(cfg, outs):
def schedule_batch_matmul_int8(cfg, outs):
"""Schedule for batch_matmul_vnni"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

def _callback(op):
if "batch_matmul_vnni" in op.tag:
if "batch_matmul_int8" in op.tag:
layout_trans = op.input_tensors[1]
batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)
if target_has_amx(mcpu):
batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans)
elif target_has_vnni(mcpu):
batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)

traverse_inline(s, outs[0].op, _callback)
return s
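The registered schedule above keeps the VNNI path and adds the AMX one, choosing between them from the target's mcpu. A small hedged check (assuming target_has_amx / target_has_vnni keep their current behaviour) of which path a given -mcpu would take:

from tvm.topi.x86.utils import target_has_amx, target_has_vnni

# Presumed behaviour: sapphirerapids has AMX (and VNNI), cascadelake has VNNI only.
for mcpu in ["sapphirerapids", "cascadelake", "skylake-avx512"]:
    if target_has_amx(mcpu):
        path = "batch_matmul_amx_schedule"
    elif target_has_vnni(mcpu):
        path = "batch_matmul_vnni_schedule"
    else:
        path = "default schedule (no int8 tensorization)"
    print(mcpu, "->", path)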
21 changes: 4 additions & 17 deletions python/tvm/topi/x86/dense.py
@@ -436,7 +436,7 @@ def split_k(out, rd_axis):
cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128)
return cfg["tile_k"].apply(s, out, rd_axis)

a_x, a_y = C.op.axis
a_x, a_y = C.op.axis[-2:]
(a_k,) = C.op.reduce_axis
CF = s.cache_write(C, "amx.tmm")

@@ -447,16 +447,16 @@ def split_k(out, rd_axis):
s[CF].compute_at(s[C], a_yo)

(a_k_f,) = CF.op.reduce_axis
a_x_f, a_y_f = CF.op.axis
a_x_f, a_y_f = CF.op.axis[-2:]

a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32)

a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32)
a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f)
s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f)

(m, k) = CF.op.input_tensors[0].shape
(n, c, n_i, c_i) = CF.op.input_tensors[1].shape
(m, k) = CF.op.input_tensors[0].shape[-2:]
(n, c, n_i, c_i) = CF.op.input_tensors[1].shape[-4:]
n = n * n_i

s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k)))
@@ -479,19 +479,6 @@ def split_k(out, rd_axis):
return s, fused


@autotvm.register_topi_schedule("dense_amx_int8.x86")
def schedule_dense_amx_int8(cfg, outs):
"""Create a schedule for dense_amx_int8"""
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if "dense_amx_int8" in op.tag:
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
return s


def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib):
"""Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(tensor_a.shape)
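The dense.py changes above adjust the shared AMX schedule so it no longer assumes 2-D tensors: it takes the last two spatial axes (and the trailing packed-weight dimensions), which is what lets batch_matmul, with its extra batch axis, reuse it. A tiny TE sketch with hypothetical shapes, only to illustrate the axis[-2:] indexing:

import tvm
from tvm import te

# 3-D batch_matmul-style output; dense would give a 2-D output instead.
A = te.placeholder((8, 32, 64), name="A", dtype="int32")
C = te.compute((8, 32, 64), lambda b, i, j: A[b, i, j] + 1, name="C")
s = te.create_schedule(C.op)

# axis[-2:] picks (i, j) whether or not a leading batch axis is present,
# which is what allows dense_amx_int8_schedule to be reused for batch_matmul.
a_x, a_y = C.op.axis[-2:]
print(a_x, a_y)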
55 changes: 55 additions & 0 deletions tests/python/relay/test_op_level10.py
@@ -520,6 +520,61 @@ def test_batch_matmul_vnni(b, m, n, k):
np.testing.assert_equal(out, ref)


@pytest.mark.skip("skip due to AMX feature not available yet")
@pytest.mark.parametrize(
"b,m,n,k",
[
(16, 32, 32, 128),
(16, 32, 32, 127),
(16, 32, 31, 128),
],
)
def test_batch_matmul_amx(b, m, n, k):
amx_init = tvm.get_global_func("runtime.amx_init")
amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig")
assert amx_init()
assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns.

x_shape = (b, m, k)
y_shape = (b, n, k)
z_shape = (b, m, n)

for lhs_dtype in ["uint8", "int8"]:
x = relay.var("x", shape=x_shape, dtype=lhs_dtype)
y = relay.var("y", shape=y_shape, dtype="int8")
z = relay.var("z", shape=z_shape, dtype="int32")
bmm = relay.nn.batch_matmul(x, y, out_dtype="int32")
out = bmm + z
mod = tvm.IRModule.from_expr(out)

target = "llvm -mcpu=sapphirerapids"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target)

asm = lib.lib.get_source("asm")
assert "tilezero" in asm
assert "tileloaddt1" in asm
assert "tdpbusd" in asm
assert "tilestored" in asm

dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

x_np = np.random.uniform(1, 10, size=x_shape).astype(lhs_dtype)
y_np = np.random.uniform(1, 10, size=y_shape).astype("int8")
z_np = np.random.uniform(1, 10, size=z_shape).astype("int32")

runtime.set_input("x", x_np)
runtime.set_input("y", y_np)
runtime.set_input("z", z_np)
runtime.run()

out = runtime.get_output(0).numpy()
ref = tvm.topi.testing.batch_matmul(x_np, y_np, out_dtype="int32") + z_np

np.testing.assert_equal(out, ref)


@pytest.mark.skip("Requires GFX10 AMDGPU")
def test_batch_matmul_rocm_sdot4():
x_shape = (16, 32, 96)

