[Tensorize][runtime] Add support for AMX (Advanced Matrix Extensions) through Tensor intrinsics (#13642)

* add AMX config functions and build option.

* add AMX tensor intrinsics and u8s8s32 matmul test case

* add int8 dense kernel using AMX tensorize

* add int8 dense kernel using AMX tensorize

* add AMX init() and config() for the dense test case

* correct the AMX config

* fix lint.

* fix dense schedule

* remove signal stack operation

* fix nit

* unify AMX and VNNI compute, remove the duplicate

* fix lint

* adopt the x86 int8 dense compute method

* Revert "adopt to x86 int8 dense compute method;"

This reverts commit 5718a05.

* restore schedule rules specifically for ms dense_vnni

* add vnni ms target attributes

* remove the misoperations

* Revert "restore schedule ruls specially for ms dense_vnni"

This reverts commit 2bda03e.

* add vnni ms target attributes and remove misops

* Revert "add vnni ms target attributes"

This reverts commit c2e9f26.

* remove the misops
Qianshui-Jiang authored Jan 5, 2023
1 parent bf0607b commit 07a5a9e
Showing 14 changed files with 732 additions and 34 deletions.
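
For orientation, here is a minimal end-to-end sketch of the uint8 x int8 -> int32 dense path this commit accelerates with AMX tensor intrinsics. The shapes, the `llvm -mcpu=sapphirerapids` target string, and the assumption that the local LLVM recognizes that CPU are illustrative, not part of the commit.

```python
# Hedged end-to-end sketch (shapes and target assumed) of the u8s8s32 dense path.
import tvm
from tvm import relay

M, N, K = 128, 768, 768  # N % 16 == 0 and K % 4 == 0 so the packed NC16n4c layout applies
data = relay.var("data", shape=(M, K), dtype="uint8")
weight = relay.var("weight", shape=(N, K), dtype="int8")
func = relay.Function([data, weight], relay.nn.dense(data, weight, out_dtype="int32"))
mod = tvm.IRModule.from_expr(func)

# On a Sapphire Rapids target, alter-op-layout repacks the weight, the x86 strategy picks
# dense_int8, and the schedule tensorizes with the AMX TMUL intrinsics.
target = "llvm -mcpu=sapphirerapids"  # requires an LLVM build that knows this CPU
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
```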
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -87,6 +87,7 @@ tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
# Contrib library options
tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF)
tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_AMX "Enable Intel AMX" OFF)
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
@@ -497,6 +498,7 @@ include(cmake/modules/contrib/EthosU.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/AMX.cmake)
include(cmake/modules/contrib/CUTLASS.cmake)
include(cmake/modules/contrib/ExampleTargetHooks.cmake)
include(cmake/modules/contrib/Random.cmake)
3 changes: 3 additions & 0 deletions cmake/config.cmake
@@ -179,6 +179,9 @@ set(USE_MKL OFF)
# - OFF: Disable DNNL
set(USE_DNNL OFF)

# Whether use Intel AMX instructions.
set(USE_AMX OFF)

# Whether use OpenMP thread pool, choices: gnu, intel
# Note: "gnu" uses gomp library, "intel" uses iomp5 library
set(USE_OPENMP none)
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
@@ -65,6 +65,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_CUDNN="${USE_CUDNN}"
TVM_INFO_USE_CUSTOM_LOGGING="${USE_CUSTOM_LOGGING}"
TVM_INFO_USE_CUTLASS="${USE_CUTLASS}"
TVM_INFO_USE_AMX="${USE_AMX}"
TVM_INFO_USE_DNNL="${USE_DNNL}"
TVM_INFO_USE_ETHOSN="${USE_ETHOSN}"
TVM_INFO_USE_FALLBACK_STL_MAP="${USE_FALLBACK_STL_MAP}"
23 changes: 23 additions & 0 deletions cmake/modules/contrib/AMX.cmake
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_AMX)
file(GLOB AMX_RUNTIME_CONFIG src/runtime/contrib/amx/amx_config.cc)
list(APPEND COMPILER_SRCS ${AMX_RUNTIME_CONFIG})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids")
message(STATUS "Build with Intel AMX support...")
endif()
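
To pick up this module, set `set(USE_AMX ON)` in `cmake/config.cmake` (the option added above) and rebuild. A quick way to confirm the flag took effect is to query the library build info; this snippet is a hedged sketch, not part of the commit, and the `USE_AMX` key name is assumed to mirror the `TVM_INFO_USE_AMX` entry added to LibInfo.cmake.

```python
# Hedged sketch: confirm the local TVM build was configured with USE_AMX.
import tvm

# Expected to print "ON" after set(USE_AMX ON) in config.cmake and a rebuild.
print(tvm.support.libinfo().get("USE_AMX"))
```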
9 changes: 4 additions & 5 deletions python/tvm/relay/op/strategy/x86.py
@@ -591,18 +591,17 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
"""dense_pack x86 strategy"""
strategy = _op.OpStrategy()

if (
inputs[0].dtype == "uint8"
and inputs[1].dtype == "int8"
and out_type.dtype == "int32"
and attrs["weight_layout"] == "NC16n4c"
):
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_vnni),
wrap_topi_schedule(topi.x86.schedule_dense_vnni),
name="dense_vnni.x86",
plevel=12,
wrap_compute_dense(topi.x86.dense_int8),
wrap_topi_schedule(topi.x86.schedule_dense_int8),
name="dense_int8.x86",
plevel=13,
)
else:
strategy.add_implementation(
167 changes: 143 additions & 24 deletions python/tvm/topi/x86/dense.py
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals,unused-variable
# pylint: disable=no-value-for-parameter
# pylint: disable=invalid-name,too-many-locals,unused-argument
# pylint: disable=no-value-for-parameter,unused-variable
"""x86 dense operators"""
from __future__ import absolute_import as _abs

@@ -27,7 +27,9 @@
from .. import generic, tag
from ..utils import get_const_tuple, traverse_inline
from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake
from .utils import get_simd_32bit_lanes
from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids
from .tensor_intrin import acc_32x32_int32_sapphirerapids
from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx


def _schedule_dense_pack_template(cfg, s, C, O):
Expand Down Expand Up @@ -278,11 +280,45 @@ def _callback(op):
return s


def dense_vnni_compute(cfg, X, packed_w, bias=None):
@autotvm.register_topi_compute("dense_int8.x86")
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
"""Compute for uint8 x int8 -> int32 dense"""
if out_dtype is None:
out_dtype = data.dtype
assert len(weight.shape) == 4
assert data.dtype == "uint8" and weight.dtype == "int8"
_, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim
assert n_inner == 16 and k_inner == 4
return dense_int8_compute(cfg, data, weight, bias)


@autotvm.register_topi_schedule("dense_int8.x86")
def schedule_dense_int8(cfg, outs):
"""Create a schedule for dense__int8"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

def _callback(op):
if "dense_int8" in op.tag:
if target_has_amx(mcpu):
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
elif target_has_vnni(mcpu):
dense_vnni_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
return s


def dense_int8_compute(cfg, X, packed_w, bias=None):
"""Compute for uint8 x int8 -> int32 dense"""
m, k = X.shape
n_o, _, n_i, _ = packed_w.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_vnni(mcpu):
target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"}
else:
target_attr = None

C = te.compute(
(m, n_o * n_i),
Expand All @@ -293,16 +329,13 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None):
),
axis=ak,
),
tag="dense_vnni",
attrs={"schedule_rule": "dense_vnni"},
tag="dense_int8",
attrs=target_attr,
)

if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)

a_y, _ = C.op.axis
cfg.define_split("tile_y", a_y, num_outputs=2)

return C


@@ -317,6 +350,7 @@ def split_y(out):
if cfg.is_fallback:
return s[out].split(a_y, factor=default_y_split_factor)

cfg.define_split("tile_y", a_y, num_outputs=2)
return cfg["tile_y"].apply(s, out, a_y)

(a_k,) = C.op.reduce_axis
@@ -348,26 +382,111 @@ def split_y(out):
return s, fused


@autotvm.register_topi_compute("dense_vnni.x86")
def dense_vnni(cfg, data, weight, bias=None, out_dtype=None):
"""Compute for uint8 x int8 -> int32 dense"""
if out_dtype is None:
out_dtype = data.dtype
assert len(weight.shape) == 4
assert data.dtype == "uint8" and weight.dtype == "int8"
_, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim
assert n_inner == 16 and k_inner == 4
return dense_vnni_compute(cfg, data, weight, bias)
def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True):
"""Schedule dense compute using AMX TMUL instruction"""
# C: The output of GEMM
# O: The output of the fused op
def split_x(out):
default_x_split_factor1 = 32
default_x_split_factor2 = 2
default_x_split_factor3 = 2
default_x_split_factor4 = 2
a_x = s[out].op.axis[-2]

if cfg.is_fallback:
a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1)
a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2)
a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3)
a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4)
return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi]

cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x: x.size[-1] == 32)
return cfg["tile_x"].apply(s, out, a_x)

def split_y(out):
default_y_split_factor1 = 32
default_y_split_factor2 = 4
default_y_split_factor3 = 4
default_y_split_factor4 = 4
a_y = s[out].op.axis[-1]

if cfg.is_fallback:
a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1)
a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2)
a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3)
a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4)
return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo]

cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y: y.size[-1] == 32)
return cfg["tile_y"].apply(s, out, a_y)

def split_k(out, rd_axis):
default_k_split_factor1 = 128
default_k_split_factor2 = 2
default_k_split_factor3 = 2
default_k_split_factor4 = 2

if cfg.is_fallback:
a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1)
a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2)
a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3)
a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4)
return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki]

cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128)
return cfg["tile_k"].apply(s, out, rd_axis)

a_x, a_y = C.op.axis
(a_k,) = C.op.reduce_axis
CF = s.cache_write(C, "amx.tmm")

a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C)
a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C)
s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi)

s[CF].compute_at(s[C], a_yo)

(a_k_f,) = CF.op.reduce_axis
a_x_f, a_y_f = CF.op.axis

a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32)

a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32)
a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f)
s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f)

(m, k) = CF.op.input_tensors[0].shape
(n, c, n_i, c_i) = CF.op.input_tensors[1].shape
n = n * n_i

s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k)))
s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n)))

if C == O:
fused = s[O].fuse(a_x3, a_y3)
else:
a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O)
a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O)

s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, a_xi)
s[O].vectorize(a_xi)

fused = s[O].fuse(a_x3, a_y3)

if do_parallel:
s[O].parallel(fused)

return s, fused


@autotvm.register_topi_schedule("dense_vnni.x86")
def schedule_dense_vnni(cfg, outs):
"""Create a schedule for dense_vnni"""
@autotvm.register_topi_schedule("dense_amx_int8.x86")
def schedule_dense_amx_int8(cfg, outs):
"""Create a schedule for dense_amx_int8"""
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if "dense_vnni" in op.tag:
dense_vnni_schedule(cfg, s, op.output(0), outs[0])
if "dense_amx_int8" in op.tag:
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
return s
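
As a rough illustration of how the pieces above fit together, the sketch below drives the new compute and schedule directly at the TE level. The shapes (multiples of the 32/32/128 tile splits), the pre-packed NC16n4c weight placeholder, and the `sapphirerapids` target are assumptions for the example; actually building and running the lowered code additionally needs an LLVM new enough to target Sapphire Rapids plus AMX tile configuration at runtime.

```python
# Hedged TE-level sketch of the int8 dense path added above (not a drop-in test).
import tvm
from tvm import te, topi

M, N, K = 32, 32, 128  # chosen to satisfy the tile splits in dense_amx_int8_schedule
with tvm.target.Target("llvm -mcpu=sapphirerapids"):
    X = te.placeholder((M, K), name="X", dtype="uint8")
    # Weight pre-packed into the NC16n4c layout that dense_int8 expects.
    W = te.placeholder((N // 16, K // 4, 16, 4), name="W", dtype="int8")
    C = topi.x86.dense_int8(X, W)
    s = topi.x86.schedule_dense_int8([C])
    # Inspect the tensorized schedule; codegen requires LLVM support for -mcpu=sapphirerapids.
    print(tvm.lower(s, [X, W, C], simple_mode=True))
```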
10 changes: 6 additions & 4 deletions python/tvm/topi/x86/dense_alter_op.py
@@ -25,13 +25,15 @@
from ..utils import get_const_tuple
from ..nn import dense_alter_layout
from .utils import target_has_vnni
from .utils import target_has_amx
from .. import nn


def check_vnni_applicable(x, y, allow_padding=False):
def check_inst_applicable(x, y, allow_padding=False):
mcpu = tvm.target.Target.current().mcpu
simd_avai = target_has_vnni(mcpu) or target_has_amx(mcpu)
return (
target_has_vnni(mcpu)
simd_avai
and "int8" in x.dtype
and "int8" in y.dtype
and (allow_padding or (y.shape[-2] % 16 == 0 and y.shape[-1] % 4 == 0))
@@ -47,7 +49,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
M, K = get_const_tuple(data_tensor.shape)
N, _ = get_const_tuple(weight_tensor.shape)

if check_vnni_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8":
if check_inst_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8":
weight_layout = "NC16n4c"
return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)

@@ -87,7 +89,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
"""Legalizes s8, s8 -> s32 GEMM op for VNNI."""
if (
check_vnni_applicable(arg_types[0], arg_types[1], allow_padding=True)
check_inst_applicable(arg_types[0], arg_types[1], allow_padding=True)
and arg_types[0].dtype == "int8"
):
x, y = inputs
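
The relaxed applicability check above means the layout rewrite now fires on AMX-capable targets as well as VNNI ones. A hedged sketch of the effect (shapes assumed; the pattern mirrors how TVM's alter-op-layout tests drive the pass):

```python
# Hedged sketch: on an AMX/VNNI-capable target, a uint8 x int8 dense is rewritten to
# nn.contrib_dense_pack with weight_layout="NC16n4c" during AlterOpLayout.
import tvm
from tvm import relay

data = relay.var("data", shape=(64, 256), dtype="uint8")
weight = relay.var("weight", shape=(128, 256), dtype="int8")
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight, out_dtype="int32"))
mod = relay.transform.InferType()(mod)

with tvm.target.Target("llvm -mcpu=sapphirerapids"):
    mod = relay.transform.AlterOpLayout()(mod)
print(mod)  # the dense call should now appear as nn.contrib_dense_pack(..., "NC16n4c")
```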