[Tensorize][runtime] Add support for AMX (Advanced Matrix Extensions) through Tensor intrinsics #13642

Merged: 22 commits, Jan 5, 2023
Changes from 13 commits

Commits:
3230886  add AMX config functions and building option. (Qianshui-Jiang, Dec 19, 2022)
2312242  amx tensor intrinsics and u8s8s32 matmul testcase (Qianshui-Jiang, Dec 19, 2022)
3e2fc4e  add int8 dense kernel use amx tensorize (Qianshui-Jiang, Dec 19, 2022)
3f19099  add int8 dense kernel use amx tensorize (Qianshui-Jiang, Dec 19, 2022)
c53c394  add amx init() and config() for dense test case (Qianshui-Jiang, Dec 19, 2022)
98b9a23  Merge branch 'amx_int8_dev' of https://github.com/Qianshui-Jiang/tvm … (Qianshui-Jiang, Dec 19, 2022)
79d6636  correct the amx config (Qianshui-Jiang, Dec 19, 2022)
b866673  fix lint. (Qianshui-Jiang, Dec 19, 2022)
48fa37e  fix dense schedule (Qianshui-Jiang, Dec 19, 2022)
dd1eb24  remove operation of signal stack (Qianshui-Jiang, Dec 23, 2022)
73f45ef  fix nit (Qianshui-Jiang, Dec 25, 2022)
b921052  unified amx and vnni compute, remove dup one (Qianshui-Jiang, Dec 29, 2022)
e749360  fix lint (Qianshui-Jiang, Dec 29, 2022)
5718a05  adopt to x86 int8 dense compute method; (Qianshui-Jiang, Dec 31, 2022)
581331a  Revert "adopt to x86 int8 dense compute method;" (Qianshui-Jiang, Jan 1, 2023)
2bda03e  restore schedule ruls specially for ms dense_vnni (Qianshui-Jiang, Jan 1, 2023)
c2e9f26  add vnni ms target attributes (Qianshui-Jiang, Jan 4, 2023)
4469fd9  remove the misoperations (Qianshui-Jiang, Jan 5, 2023)
f763d52  Revert "restore schedule ruls specially for ms dense_vnni" (Qianshui-Jiang, Jan 5, 2023)
1f59aff  add vnni ms target attributes and remove misops (Qianshui-Jiang, Jan 5, 2023)
383d0b2  Revert "add vnni ms target attributes" (Qianshui-Jiang, Jan 5, 2023)
9422363  remove the misops (Qianshui-Jiang, Jan 5, 2023)
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -86,6 +86,7 @@ tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
# Contrib library options
tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF)
tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_AMX "Enable Intel AMX" OFF)
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
@@ -495,6 +496,7 @@ include(cmake/modules/contrib/EthosU.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/AMX.cmake)
include(cmake/modules/contrib/CUTLASS.cmake)
include(cmake/modules/contrib/ExampleTargetHooks.cmake)
include(cmake/modules/contrib/Random.cmake)
3 changes: 3 additions & 0 deletions cmake/config.cmake
@@ -174,6 +174,9 @@ set(USE_MKL OFF)
# - OFF: Disable DNNL
set(USE_DNNL OFF)

# Whether use Intel AMX instructions.
set(USE_AMX OFF)

# Whether use OpenMP thread pool, choices: gnu, intel
# Note: "gnu" uses gomp library, "intel" uses iomp5 library
set(USE_OPENMP none)
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
@@ -65,6 +65,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_CUDNN="${USE_CUDNN}"
TVM_INFO_USE_CUSTOM_LOGGING="${USE_CUSTOM_LOGGING}"
TVM_INFO_USE_CUTLASS="${USE_CUTLASS}"
TVM_INFO_USE_AMX="${USE_AMX}"
TVM_INFO_USE_DNNL="${USE_DNNL}"
TVM_INFO_USE_ETHOSN="${USE_ETHOSN}"
TVM_INFO_USE_FALLBACK_STL_MAP="${USE_FALLBACK_STL_MAP}"
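With TVM_INFO_USE_AMX added above, the build flag becomes visible from Python through the usual libinfo plumbing. A minimal probe, assuming the entry is exposed under the key "USE_AMX":

```python
import tvm

# Reports whether this TVM build was configured with USE_AMX=ON; the key name
# mirrors the TVM_INFO_USE_AMX entry registered in LibInfo.cmake above.
print(tvm.support.libinfo().get("USE_AMX", "not recorded"))
```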
23 changes: 23 additions & 0 deletions cmake/modules/contrib/AMX.cmake
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_AMX)
file(GLOB AMX_RUNTIME_CONFIG src/runtime/contrib/amx/amx_config.cc)
list(APPEND COMPILER_SRCS ${AMX_RUNTIME_CONFIG})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids")
message(STATUS "Build with Intel AMX support...")
endif()
9 changes: 4 additions & 5 deletions python/tvm/relay/op/strategy/x86.py
@@ -591,18 +591,17 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
"""dense_pack x86 strategy"""
strategy = _op.OpStrategy()

if (
inputs[0].dtype == "uint8"
and inputs[1].dtype == "int8"
and out_type.dtype == "int32"
and attrs["weight_layout"] == "NC16n4c"
):
strategy.add_implementation(
wrap_compute_dense(topi.x86.dense_vnni),
wrap_topi_schedule(topi.x86.schedule_dense_vnni),
name="dense_vnni.x86",
plevel=12,
wrap_compute_dense(topi.x86.dense_int8),
wrap_topi_schedule(topi.x86.schedule_dense_int8),
name="dense_int8.x86",
plevel=13,
)
else:
strategy.add_implementation(
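For context, a rough sketch of the path this strategy change serves: building a uint8 x int8 -> int32 dense for an AMX-capable CPU. The target string and tensor shapes are illustrative assumptions; the alter-op pass (see dense_alter_op.py below) packs the weight into the NC16n4c layout that dense_int8 requires.

```python
import tvm
from tvm import relay

# uint8 x int8 -> int32 dense; N and K are multiples of 16 and 4 so the
# weight can be packed into the NC16n4c layout checked by the strategy.
data = relay.var("data", shape=(32, 128), dtype="uint8")
weight = relay.var("weight", shape=(128, 128), dtype="int8")
dense = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], dense))

# Assumed AMX-capable target; on such a target dense_pack_strategy_cpu picks
# dense_int8.x86 (plevel=13) for the packed dense op.
target = tvm.target.Target("llvm -mcpu=sapphirerapids")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
```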
162 changes: 138 additions & 24 deletions python/tvm/topi/x86/dense.py
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals,unused-variable
# pylint: disable=no-value-for-parameter
# pylint: disable=invalid-name,too-many-locals,unused-argument
# pylint: disable=no-value-for-parameter,unused-variable
"""x86 dense operators"""
from __future__ import absolute_import as _abs

@@ -27,7 +27,9 @@
from .. import generic, tag
from ..utils import get_const_tuple, traverse_inline
from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake
from .utils import get_simd_32bit_lanes
from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids
from .tensor_intrin import acc_32x32_int32_sapphirerapids
from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx


def _schedule_dense_pack_template(cfg, s, C, O):
@@ -278,7 +280,36 @@ def _callback(op):
return s


def dense_vnni_compute(cfg, X, packed_w, bias=None):
@autotvm.register_topi_compute("dense_int8.x86")
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
"""Compute for uint8 x int8 -> int32 dense"""
if out_dtype is None:
out_dtype = data.dtype
assert len(weight.shape) == 4
assert data.dtype == "uint8" and weight.dtype == "int8"
_, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim
assert n_inner == 16 and k_inner == 4
return dense_int8_compute(cfg, data, weight, bias)


@autotvm.register_topi_schedule("dense_int8.x86")
def schedule_dense_int8(cfg, outs):
"""Create a schedule for dense__int8"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

def _callback(op):
if "dense_int8" in op.tag:
if target_has_amx(mcpu):
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
elif target_has_vnni(mcpu):
dense_vnni_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
return s


def dense_int8_compute(cfg, X, packed_w, bias=None):
"""Compute for uint8 x int8 -> int32 dense"""
m, k = X.shape
n_o, _, n_i, _ = packed_w.shape
@@ -293,16 +324,13 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None):
),
axis=ak,
),
tag="dense_vnni",
attrs={"schedule_rule": "dense_vnni"},
masahi (Member): This shouldn't be removed; it is used here:
register_func("meta_schedule.x86.dense_vnni", schedule_rule_dense_vnni)
Since this only affects MetaSchedule, you don't have to provide this value for AMX. Only when dense_int8_compute is called for VNNI do you need to provide this attribute.

Qianshui-Jiang (Contributor, Author): @masahi Does this case only use the x86 int8 compute method and inject a particular TIR schedule? Can we just change the attribute from dense_vnni to the dense_int8 used here?

masahi (Member, Jan 1, 2023): Yes, but it is important to specify that this compute is used for VNNI. If the schedule-rule annotation only says "dense_int8", we don't know which intrinsic to tensorize this compute with.

Qianshui-Jiang (Contributor, Author): @masahi Sorry, I may have caused a misunderstanding. I mean: can we use dense_int8 in that test case, since it injects the VNNI intrinsic explicitly? By default, the relay build flow checks whether VNNI or AMX is available and chooses a different schedule. I've verified that the test case is still functional after the small modification in the commit below.

masahi (Member, Jan 1, 2023): Of course the test still works, because it was already written for VNNI. The point is that the name dense_vnni tells us that this particular TE expression is meant for VNNI tensorization, in particular that the weight is pre-packed appropriately. A generic name like dense_int8 does not provide that information. Just because AMX can use the same layout doesn't mean we can use dense_int8: if MetaSchedule finds a compute annotated with dense_int8, it cannot tell whether it should apply VNNI or AMX tensorization (if the latter is supported by MS in the future). So please revert that commit, then restore and pass attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"} when you create a compute expression for VNNI. schedule_rule is not relevant for AMX for now.

Qianshui-Jiang (Contributor, Author): @masahi Got it; the meta-schedule rule for dense_vnni is restored. It is kept, but AMX does not use it for now. (A small question: how do we know, inside this compute expression, whether it was created for VNNI or for AMX?)

masahi (Member): I suggest checking the target string in the op strategy and creating separate computes for VNNI and AMX, rather than using the same function, dense_int8, for both.

tag="dense_int8",
attrs={"schedule_rule": "dense_int8"},
)

if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)

a_y, _ = C.op.axis
cfg.define_split("tile_y", a_y, num_outputs=2)

return C


@@ -317,6 +345,7 @@ def split_y(out):
if cfg.is_fallback:
return s[out].split(a_y, factor=default_y_split_factor)

cfg.define_split("tile_y", a_y, num_outputs=2)
return cfg["tile_y"].apply(s, out, a_y)

(a_k,) = C.op.reduce_axis
@@ -348,26 +377,111 @@
return s, fused


@autotvm.register_topi_compute("dense_vnni.x86")
def dense_vnni(cfg, data, weight, bias=None, out_dtype=None):
"""Compute for uint8 x int8 -> int32 dense"""
if out_dtype is None:
out_dtype = data.dtype
assert len(weight.shape) == 4
assert data.dtype == "uint8" and weight.dtype == "int8"
_, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim
assert n_inner == 16 and k_inner == 4
return dense_vnni_compute(cfg, data, weight, bias)
def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True):
"""Schedule dense compute using AMX TMUL instruction"""
# C: The output of GEMM
# O: The output of the fused op
def split_x(out):
default_x_split_factor1 = 32
default_x_split_factor2 = 2
default_x_split_factor3 = 2
default_x_split_factor4 = 2
a_x = s[out].op.axis[-2]

if cfg.is_fallback:
a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1)
a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2)
a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3)
a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4)
return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi]

cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x: x.size[-1] == 32)
return cfg["tile_x"].apply(s, out, a_x)

def split_y(out):
default_y_split_factor1 = 32
default_y_split_factor2 = 4
default_y_split_factor3 = 4
default_y_split_factor4 = 4
a_y = s[out].op.axis[-1]

if cfg.is_fallback:
a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1)
a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2)
a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3)
a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4)
return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo]

cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y: y.size[-1] == 32)
return cfg["tile_y"].apply(s, out, a_y)

def split_k(out, rd_axis):
default_k_split_factor1 = 128
default_k_split_factor2 = 2
default_k_split_factor3 = 2
default_k_split_factor4 = 2

if cfg.is_fallback:
a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1)
a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2)
a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3)
a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4)
return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki]

cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128)
return cfg["tile_k"].apply(s, out, rd_axis)

a_x, a_y = C.op.axis
(a_k,) = C.op.reduce_axis
CF = s.cache_write(C, "amx.tmm")

a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C)
a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C)
s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi)

s[CF].compute_at(s[C], a_yo)

(a_k_f,) = CF.op.reduce_axis
a_x_f, a_y_f = CF.op.axis

a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32)

a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32)
a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f)
s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f)

(m, k) = CF.op.input_tensors[0].shape
(n, c, n_i, c_i) = CF.op.input_tensors[1].shape
n = n * n_i

s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k)))
s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n)))

if C == O:
fused = s[O].fuse(a_x3, a_y3)
else:
a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O)
a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O)

s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, a_xi)
s[O].vectorize(a_xi)

fused = s[O].fuse(a_x3, a_y3)

if do_parallel:
s[O].parallel(fused)

return s, fused


@autotvm.register_topi_schedule("dense_vnni.x86")
def schedule_dense_vnni(cfg, outs):
"""Create a schedule for dense_vnni"""
@autotvm.register_topi_schedule("dense_amx_int8.x86")
def schedule_dense_amx_int8(cfg, outs):
"""Create a schedule for dense_amx_int8"""
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if "dense_vnni" in op.tag:
dense_vnni_schedule(cfg, s, op.output(0), outs[0])
if "dense_amx_int8" in op.tag:
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
return s
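The review thread above revolves around the schedule_rule attribute consumed by MetaSchedule. Below is a minimal sketch of that mechanism, assuming the registration shape quoted in the thread; the real rule in TVM applies the VNNI tensorization schedule and is considerably more involved.

```python
import tvm

def schedule_rule_dense_vnni(sch, block):
    # Placeholder body: the real rule tensorizes the matched block with the
    # VNNI dot-product intrinsic and returns the candidate schedules.
    return [sch]

# A compute created for VNNI carries attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"};
# MetaSchedule dispatches to whatever is registered under that name.
# override=True is used because TVM itself already registers this rule.
tvm.register_func("meta_schedule.x86.dense_vnni", schedule_rule_dense_vnni, override=True)
```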
10 changes: 6 additions & 4 deletions python/tvm/topi/x86/dense_alter_op.py
@@ -25,13 +25,15 @@
from ..utils import get_const_tuple
from ..nn import dense_alter_layout
from .utils import target_has_vnni
from .utils import target_has_amx
from .. import nn


def check_vnni_applicable(x, y, allow_padding=False):
def check_inst_applicable(x, y, allow_padding=False):
mcpu = tvm.target.Target.current().mcpu
simd_avai = target_has_vnni(mcpu) or target_has_amx(mcpu)
return (
target_has_vnni(mcpu)
simd_avai
and "int8" in x.dtype
and "int8" in y.dtype
and (allow_padding or (y.shape[-2] % 16 == 0 and y.shape[-1] % 4 == 0))
@@ -47,7 +49,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
M, K = get_const_tuple(data_tensor.shape)
N, _ = get_const_tuple(weight_tensor.shape)

if check_vnni_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8":
if check_inst_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8":
weight_layout = "NC16n4c"
return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)

@@ -87,7 +89,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False):
"""Legalizes s8, s8 -> s32 GEMM op for VNNI."""
if (
check_vnni_applicable(arg_types[0], arg_types[1], allow_padding=True)
check_inst_applicable(arg_types[0], arg_types[1], allow_padding=True)
and arg_types[0].dtype == "int8"
):
x, y = inputs
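check_inst_applicable gates the layout rewrite on the current target's mcpu. A quick probe of the two helpers it relies on, assuming they accept an mcpu string exactly as used above:

```python
from tvm.topi.x86.utils import target_has_amx, target_has_vnni

# True for AMX-capable mcpu strings such as "sapphirerapids".
print(target_has_amx("sapphirerapids"))
# True for VNNI-capable mcpu strings such as "cascadelake".
print(target_has_vnni("cascadelake"))
```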