diff --git a/CMakeLists.txt b/CMakeLists.txt index 119bf8325c8c..18b85a582279 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -86,6 +86,7 @@ tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson") # Contrib library options tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF) tvm_option(USE_BLAS "The blas library to be linked" none) +tvm_option(USE_AMX "Enable Intel AMX" OFF) tvm_option(USE_MKL "MKL root path when use MKL blas" OFF) tvm_option(USE_DNNL "Enable DNNL codegen" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF) @@ -495,6 +496,7 @@ include(cmake/modules/contrib/EthosU.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/CODEGENC.cmake) include(cmake/modules/contrib/DNNL.cmake) +include(cmake/modules/contrib/AMX.cmake) include(cmake/modules/contrib/CUTLASS.cmake) include(cmake/modules/contrib/ExampleTargetHooks.cmake) include(cmake/modules/contrib/Random.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index 679f5c459e87..3189815d237c 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -174,6 +174,9 @@ set(USE_MKL OFF) # - OFF: Disable DNNL set(USE_DNNL OFF) +# Whether use Intel AMX instructions. 
+set(USE_AMX OFF) + # Whether use OpenMP thread pool, choices: gnu, intel # Note: "gnu" uses gomp library, "intel" uses iomp5 library set(USE_OPENMP none) diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 7c24088c0ad2..4fc7ed3262d5 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -65,6 +65,7 @@ function(add_lib_info src_file) TVM_INFO_USE_CUDNN="${USE_CUDNN}" TVM_INFO_USE_CUSTOM_LOGGING="${USE_CUSTOM_LOGGING}" TVM_INFO_USE_CUTLASS="${USE_CUTLASS}" + TVM_INFO_USE_AMX="${USE_AMX}" TVM_INFO_USE_DNNL="${USE_DNNL}" TVM_INFO_USE_ETHOSN="${USE_ETHOSN}" TVM_INFO_USE_FALLBACK_STL_MAP="${USE_FALLBACK_STL_MAP}" diff --git a/cmake/modules/contrib/AMX.cmake b/cmake/modules/contrib/AMX.cmake new file mode 100644 index 000000000000..ac349c4336a2 --- /dev/null +++ b/cmake/modules/contrib/AMX.cmake @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +if(USE_AMX) + file(GLOB AMX_RUNTIME_CONFIG src/runtime/contrib/amx/amx_config.cc) + list(APPEND COMPILER_SRCS ${AMX_RUNTIME_CONFIG}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") + message(STATUS "Build with Intel AMX support...") +endif() diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 7ff4dbc0ad1b..4585809f63e1 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -591,7 +591,6 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def dense_pack_strategy_cpu(attrs, inputs, out_type, target): """dense_pack x86 strategy""" strategy = _op.OpStrategy() - if ( inputs[0].dtype == "uint8" and inputs[1].dtype == "int8" @@ -599,10 +598,10 @@ def dense_pack_strategy_cpu(attrs, inputs, out_type, target): and attrs["weight_layout"] == "NC16n4c" ): strategy.add_implementation( - wrap_compute_dense(topi.x86.dense_vnni), - wrap_topi_schedule(topi.x86.schedule_dense_vnni), - name="dense_vnni.x86", - plevel=12, + wrap_compute_dense(topi.x86.dense_int8), + wrap_topi_schedule(topi.x86.schedule_dense_int8), + name="dense_int8.x86", + plevel=13, ) else: strategy.add_implementation( diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 65a803781a57..ada19d598cdf 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,too-many-locals,unused-variable -# pylint: disable=no-value-for-parameter +# pylint: disable=invalid-name,too-many-locals,unused-argument +# pylint: disable=no-value-for-parameter,unused-variable """x86 dense operators""" from __future__ import absolute_import as _abs @@ -27,7 +27,9 @@ from .. 
import generic, tag from ..utils import get_const_tuple, traverse_inline from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake -from .utils import get_simd_32bit_lanes +from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids +from .tensor_intrin import acc_32x32_int32_sapphirerapids +from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx def _schedule_dense_pack_template(cfg, s, C, O): @@ -278,11 +280,45 @@ def _callback(op): return s -def dense_vnni_compute(cfg, X, packed_w, bias=None): +@autotvm.register_topi_compute("dense_int8.x86") +def dense_int8(cfg, data, weight, bias=None, out_dtype=None): + """Compute for uint8 x int8 -> int32 dense""" + if out_dtype is None: + out_dtype = data.dtype + assert len(weight.shape) == 4 + assert data.dtype == "uint8" and weight.dtype == "int8" + _, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim + assert n_inner == 16 and k_inner == 4 + return dense_int8_compute(cfg, data, weight, bias) + + +@autotvm.register_topi_schedule("dense_int8.x86") +def schedule_dense_int8(cfg, outs): + """Create a schedule for dense_int8""" + s = te.create_schedule([x.op for x in outs]) + mcpu = tvm.target.Target.current().mcpu + + def _callback(op): + if "dense_int8" in op.tag: + if target_has_amx(mcpu): + dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) + elif target_has_vnni(mcpu): + dense_vnni_schedule(cfg, s, op.output(0), outs[0]) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def dense_int8_compute(cfg, X, packed_w, bias=None): """Compute for uint8 x int8 -> int32 dense""" m, k = X.shape n_o, _, n_i, _ = packed_w.shape ak = te.reduce_axis((0, k), name="k") + mcpu = tvm.target.Target.current().mcpu + if target_has_vnni(mcpu): + target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"} + else: + target_attr = None C = te.compute( (m, n_o * n_i), @@ -293,16 +329,13 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None): ), axis=ak, ), - tag="dense_vnni", -
attrs={"schedule_rule": "dense_vnni"}, + tag="dense_int8", + attrs=target_attr, ) if bias is not None: C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) - a_y, _ = C.op.axis - cfg.define_split("tile_y", a_y, num_outputs=2) - return C @@ -317,6 +350,7 @@ def split_y(out): if cfg.is_fallback: return s[out].split(a_y, factor=default_y_split_factor) + cfg.define_split("tile_y", a_y, num_outputs=2) return cfg["tile_y"].apply(s, out, a_y) (a_k,) = C.op.reduce_axis @@ -348,26 +382,111 @@ def split_y(out): return s, fused -@autotvm.register_topi_compute("dense_vnni.x86") -def dense_vnni(cfg, data, weight, bias=None, out_dtype=None): - """Compute for uint8 x int8 -> int32 dense""" - if out_dtype is None: - out_dtype = data.dtype - assert len(weight.shape) == 4 - assert data.dtype == "uint8" and weight.dtype == "int8" - _, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim - assert n_inner == 16 and k_inner == 4 - return dense_vnni_compute(cfg, data, weight, bias) +def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True): + """Schedule dense compute using AMX TMUL instruction""" + # C: The output of GEMM + # O: The output of the fused op + def split_x(out): + default_x_split_factor1 = 32 + default_x_split_factor2 = 2 + default_x_split_factor3 = 2 + default_x_split_factor4 = 2 + a_x = s[out].op.axis[-2] + + if cfg.is_fallback: + a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1) + a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2) + a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3) + a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4) + return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi] + + cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x: x.size[-1] == 32) + return cfg["tile_x"].apply(s, out, a_x) + + def split_y(out): + default_y_split_factor1 = 32 + default_y_split_factor2 = 4 + default_y_split_factor3 = 4 + default_y_split_factor4 = 4 + a_y = s[out].op.axis[-1] + + if 
cfg.is_fallback: + a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1) + a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2) + a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3) + a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4) + return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo] + + cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y: y.size[-1] == 32) + return cfg["tile_y"].apply(s, out, a_y) + + def split_k(out, rd_axis): + default_k_split_factor1 = 128 + default_k_split_factor2 = 2 + default_k_split_factor3 = 2 + default_k_split_factor4 = 2 + + if cfg.is_fallback: + a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1) + a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2) + a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3) + a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4) + return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki] + + cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128) + return cfg["tile_k"].apply(s, out, rd_axis) + + a_x, a_y = C.op.axis + (a_k,) = C.op.reduce_axis + CF = s.cache_write(C, "amx.tmm") + + a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C) + a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C) + s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi) + + s[CF].compute_at(s[C], a_yo) + + (a_k_f,) = CF.op.reduce_axis + a_x_f, a_y_f = CF.op.axis + + a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32) + + a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32) + a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f) + s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f) + + (m, k) = CF.op.input_tensors[0].shape + (n, c, n_i, c_i) = CF.op.input_tensors[1].shape + n = n * n_i + + s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k))) + s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n))) + + if C == O: + fused = s[O].fuse(a_x3, a_y3) + else: 
+ a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O) + a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O) + + s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, a_xi) + s[O].vectorize(a_xi) + + fused = s[O].fuse(a_x3, a_y3) + + if do_parallel: + s[O].parallel(fused) + + return s, fused -@autotvm.register_topi_schedule("dense_vnni.x86") -def schedule_dense_vnni(cfg, outs): - """Create a schedule for dense_vnni""" +@autotvm.register_topi_schedule("dense_amx_int8.x86") +def schedule_dense_amx_int8(cfg, outs): + """Create a schedule for dense_amx_int8""" s = te.create_schedule([x.op for x in outs]) def _callback(op): - if "dense_vnni" in op.tag: - dense_vnni_schedule(cfg, s, op.output(0), outs[0]) + if "dense_amx_int8" in op.tag: + dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) traverse_inline(s, outs[0].op, _callback) return s diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index fd2b184a87d2..2cb46b8291fb 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -25,13 +25,15 @@ from ..utils import get_const_tuple from ..nn import dense_alter_layout from .utils import target_has_vnni +from .utils import target_has_amx from .. 
import nn -def check_vnni_applicable(x, y, allow_padding=False): +def check_inst_applicable(x, y, allow_padding=False): mcpu = tvm.target.Target.current().mcpu + simd_avai = target_has_vnni(mcpu) or target_has_amx(mcpu) return ( - target_has_vnni(mcpu) + simd_avai and "int8" in x.dtype and "int8" in y.dtype and (allow_padding or (y.shape[-2] % 16 == 0 and y.shape[-1] % 4 == 0)) @@ -47,7 +49,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): M, K = get_const_tuple(data_tensor.shape) N, _ = get_const_tuple(weight_tensor.shape) - if check_vnni_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8": + if check_inst_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8": weight_layout = "NC16n4c" return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype) @@ -87,7 +89,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False): """Legalizes s8, s8 -> s32 GEMM op for VNNI.""" if ( - check_vnni_applicable(arg_types[0], arg_types[1], allow_padding=True) + check_inst_applicable(arg_types[0], arg_types[1], allow_padding=True) and arg_types[0].dtype == "int8" ): x, y = inputs diff --git a/python/tvm/topi/x86/tensor_intrin.py b/python/tvm/topi/x86/tensor_intrin.py index 9e91e32b20e5..3b83fecbf552 100644 --- a/python/tvm/topi/x86/tensor_intrin.py +++ b/python/tvm/topi/x86/tensor_intrin.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Core kernel of dot product of 4 Int8 operations""" -# pylint: disable=invalid-name +# pylint: disable=invalid-name,unused-variable import tvm from tvm import te import tvm.target.codegen @@ -348,3 +348,227 @@ def _instr(index): binds={data: a_buffer, kernel: b_buffer}, default_buffer_params=buffer_params, ) + + +def dot_32x128x32_u8s8s32_sapphirerapids(LDA): + """ + Int8 dot product by every 16x64 elements using AMX-TMUL Sapphire Rapids instructions. 
+ The tdpxxd instruction takes two tiles of uint8 and int8 datatype -- data[16][64] and + kernel[1][16][16][4] -- and computes a dot product of data[16][16] in int32 datatype. + + (Physically, to efficiently leverage the tile registers, we construct a 2x2 tile + matmul which performs 32x128x32 in total) + + The pseudo code is as follows: + for(k=0; k<2; k++){ + for(n=0; n<2; n++){ + tileload64(tmm_b, B) + for(m=0; m<2; m++){ + if(n==0) + tileload64(tmm_a, A) + tdpbusd(tmm_c, tmm_a, tmm_b) + } + } + } + + Args: + LDA (int): the stride of the matrix A, which is uint8 type; it is used to determine + the memory strides of the macro reduce axis. + + Returns + ------- + intrin : TensorIntrin + The Sapphire Rapids AMX-TMUL int8 tdpbusd TensorIntrin that can be used in tensorizing + schedule + """ + A = te.placeholder((32, 128), name="A", dtype="uint8") + B = te.placeholder((2, 32, 16, 4), name="B", dtype="int8") + k = te.reduce_axis((0, 128), name="k") + + C = te.compute( + (32, 32), + lambda i, j: te.sum( + A[i, k].astype("int32") + * B[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(k, 4), j % 16, k % 4].astype("int32"), + axis=k, + ), + name="C", + ) + + BA = tvm.tir.decl_buffer( + A.shape, A.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BA" + ) + BB = tvm.tir.decl_buffer( + B.shape, + B.dtype, + offset_factor=1, + strides=[te.var("ldw"), te.var("ldw"), te.var("ldw"), 1], + name="BB", + ) + BC = tvm.tir.decl_buffer( + C.shape, C.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BC", scope="amx.tmm" + ) + + def intrin_func(ins, outs): # pylint: disable=unused-variable + bufA = ins[0] + bufB = ins[1] + bufC = outs[0] + + assert LDA + _strides_A = tvm.tir.const(LDA, dtype="uint64") + _strides_B_tile = tvm.tir.const(LDA / 128, dtype="uint64") + + def init(): + ib = tvm.tir.ir_builder.create() + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(0, dtype="uint8"), + ) + ) # tile C 0 + ib.emit( +
tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(1, dtype="uint8"), + ) + ) # tile C 1 + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(2, dtype="uint8"), + ) + ) # tile C 2 + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(3, dtype="uint8"), + ) + ) # tile C 3 + + return ib.get() + + def body(): # load A, load B, dpbusd, store C + ib = tvm.tir.ir_builder.create() + + for k_tile in range(2): # reduced data blocks + for n_acc in range(2): # broadcast data blocks + tmm_B_ = tvm.tir.const(n_acc + 6, dtype="uint8") + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tileloaddt164", # load B: tmm6, tmm7 + tvm.tir.const(3, "uint8"), + tmm_B_, + bufB.access_ptr( + "r", offset=64 * 16 * (n_acc * 2 * _strides_B_tile + k_tile) + ), + tvm.tir.const(64, dtype="uint64"), + ) + ) + + for m_acc in range(2): # loaded data blocks + tmm_A_ = tvm.tir.const(m_acc + 4, dtype="uint8") + if n_acc == 0: + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tileloaddt164", # load A: , tmm4, tmm5 + tvm.tir.const(3, "uint8"), + tmm_A_, + bufA.access_ptr( + "r", offset=m_acc * 16 * _strides_A + k_tile * 64 + ), + _strides_A, + ) + ) + + tmm_C_ = tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8") + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tdpbusd", + tvm.tir.const(3, "uint8"), + tmm_C_, + tmm_A_, + tmm_B_, + ) + ) # tdpxxd + + return ib.get() + + # body, reset, store + return ( + body(), + init(), + body(), + ) + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + + +def acc_32x32_int32_sapphirerapids(LDC): + """ + Store the accumulated tile register in scope amx.tmm to global memory. + (tmm0, tmm1, tmm2, tmm3 --> global 4 tiles) + + Args: + LDC (int): the stride of the matrix C, which is int32 type and use it to + determine memory strides. 
+ + Returns + ------- + intrin : TensorIntrin + The Sapphirerapids AMX-TMUL int8 tilestored64 TensorIntrin that can be used + in tensorizing schedule + """ + A = te.placeholder((32, 32), name="A", dtype="int32") + bufA = tvm.tir.decl_buffer( + A.shape, + A.dtype, + scope="amx.tmm", + name="a_buffer", + offset_factor=1, + strides=[te.var("ldw"), 1], + ) + + C = te.compute((32, 32), lambda i, j: A[i, j], name="C") + bufC = tvm.tir.decl_buffer( + C.shape, + C.dtype, + scope="global", + name="c_buffer", + offset_factor=1, + strides=[te.var("ldw"), 1], + ) + + assert LDC + _strides_C = tvm.tir.const(4 * LDC, dtype="uint64") + + def intrin_func(ins, outs): # pylint: disable=unused-variable + ib = tvm.tir.ir_builder.create() + bufA = ins[0] + bufC = outs[0] + for n_acc in range(2): # broadcast data blocks + for m_acc in range(2): # loaded data blocks + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilestored64", + tvm.tir.const(3, "uint8"), + tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8"), + bufC.access_ptr("w", offset=n_acc * 16 + m_acc * 16 * _strides_C / 4), + _strides_C, + ) + ) + + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: bufA, C: bufC}) diff --git a/python/tvm/topi/x86/utils.py b/python/tvm/topi/x86/utils.py index c364027022da..efe5913269a1 100644 --- a/python/tvm/topi/x86/utils.py +++ b/python/tvm/topi/x86/utils.py @@ -123,6 +123,13 @@ def target_has_vnni(target): } +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_amx") +def target_has_amx(target): + return target in { + "sapphirerapids", + } + + @tvm._ffi.register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") def get_simd_32bit_lanes(): mcpu = tvm.target.Target.current().mcpu diff --git a/src/runtime/contrib/amx/amx_config.cc b/src/runtime/contrib/amx/amx_config.cc new file mode 100644 index 000000000000..2e034bd478b5 --- /dev/null +++ b/src/runtime/contrib/amx/amx_config.cc @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file src/runtime/contrib/amx/amx_config.cc + * \brief extraction of AMX configuration on x86 platforms + */ +#include +#include + +namespace tvm { +namespace runtime { + +#ifdef __linux__ +#include +#include +#include +#include +#include +#include +#include +#include + +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) +#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) +#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 + +typedef struct __tile_config { + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[8]; /* Column size of each tmm register in bytes */ + uint16_t reserved_1[8]; + uint8_t rows[8]; /* Number of rows of each tmm register */ + uint8_t reserved_2[8]; +} __tilecfg; + +typedef union __union_tile_config { + __tilecfg s; + uint8_t a[64]; +} __tilecfg_u; + +void init_tile_config(__tilecfg_u* dst, uint16_t cols, uint8_t rows) { + dst->s.palette_id = 1; + dst->s.start_row = 0; + + for (int i = 0; i < 14; i++) dst->s.reserved_0[i] = 0; + + for (int i = 0; i < 8; i++) { + dst->s.colsb[i] = cols; +
dst->s.rows[i] = rows; + dst->s.reserved_1[i] = 0; + dst->s.reserved_2[i] = 0; + } + + _tile_loadconfig(dst->a); +} + +TVM_REGISTER_GLOBAL("runtime.amx_tileconfig").set_body([](TVMArgs args, TVMRetValue* rv) { + int rows = args[0]; + int cols = args[1]; + LOG(INFO) << "rows: " << rows << ", cols:" << cols; + // -----------Config for AMX tile register---------------------- + __tilecfg_u cfg; + init_tile_config(&cfg, cols, rows); + + *rv = 1; + return; +}); + +// Register a global packed function in C++ to init the system for AMX config +TVM_REGISTER_GLOBAL("runtime.amx_init").set_body([](TVMArgs args, TVMRetValue* rv) { + // -----------Detect and request for AMX control---------------------- + uint64_t bitmask = 0; + int64_t status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + if (0 != status) { + *rv = 0; + LOG(FATAL) << "errno:" << errno << ", " << strerror(errno) + << ", status[0]: " << status << ", bitmask: " << bitmask + << ", XFEATURE_XTILEDATA setup is failed, TMUL feature is not allowed."; + return; // unreachable: LOG(FATAL) throws, so the full message must be one statement + } + if (bitmask & XFEATURE_MASK_XTILEDATA) { + *rv = 1; + return; + } // TILE_DATA feature was not detected + + status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + // if XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed + if (0 != status) { + *rv = 0; + LOG(FATAL) << "errno:" << errno << ", " << strerror(errno) + << ", status[1]: " << status << ", bitmask: " << bitmask + << ", XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed."; + return; // unreachable: LOG(FATAL) throws + } + + status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + // if XFEATURE_XTILEDATA setup is failed, can't use TMUL + if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) { + *rv = 0; + LOG(FATAL) << "errno:" << errno << ", " << strerror(errno) + << ", status[2]: " << status << ", bitmask: " << bitmask + << ", XFEATURE_XTILEDATA setup is failed, can't use TMUL."; + return; // unreachable: LOG(FATAL) throws + } + + // XFEATURE_XTILEDATA set
successfully, TMUL usage is allowed + *rv = 1; + return; +}); + +#endif +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 83477312dcc5..51dba038b6ac 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -62,6 +62,8 @@ enum class StorageRank { kWMMAAccumulator = 6, /*! \brief global scope texture memory */ kTexture = 7, + /*! \brief global scope amx tmm memory */ + kAMXTMM = 8, }; /*! @@ -149,6 +151,9 @@ struct StorageScope { } else if (s.compare(0, 7, "texture") == 0) { r.rank = StorageRank::kTexture; r.tag = s.substr(7, std::string::npos); + } else if (s.compare(0, 7, "amx.tmm") == 0) { + r.rank = StorageRank::kAMXTMM; + r.tag = s.substr(7, std::string::npos); } else { LOG(FATAL) << "unknown storage scope " << s; } diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index c0fc9881b4f5..51601cb93d6b 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -147,6 +147,10 @@ #define TVM_INFO_USE_MKL "NOT-FOUND" #endif +#ifndef TVM_INFO_USE_AMX +#define TVM_INFO_USE_AMX "NOT-FOUND" +#endif + #ifndef TVM_INFO_USE_DNNL #define TVM_INFO_USE_DNNL "NOT-FOUND" #endif @@ -270,6 +274,7 @@ TVM_DLL Map GetLibInfo() { {"USE_CUDNN", TVM_INFO_USE_CUDNN}, {"USE_CUSTOM_LOGGING", TVM_INFO_USE_CUSTOM_LOGGING}, {"USE_CUTLASS", TVM_INFO_USE_CUTLASS}, + {"USE_AMX", TVM_INFO_USE_AMX}, {"USE_DNNL", TVM_INFO_USE_DNNL}, {"USE_ETHOSN", TVM_INFO_USE_ETHOSN}, {"USE_FALLBACK_STL_MAP", TVM_INFO_USE_FALLBACK_STL_MAP}, diff --git a/tests/python/contrib/test_amx.py b/tests/python/contrib/test_amx.py new file mode 100644 index 000000000000..30da7e56fb8d --- /dev/null +++ b/tests/python/contrib/test_amx.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition + +import tvm +from tvm import relay + +from tvm import te +import tvm.testing +from tvm.topi.x86.tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids +from tvm.topi.x86.tensor_intrin import acc_32x32_int32_sapphirerapids +import numpy as np +import pytest + + +@tvm.testing.requires_llvm +@pytest.mark.skip("skip due to AMX feature not avaliable yet") +def test_amx_u8s8s32_matmul_tensorize(): + m = 1024 + k = 1024 + n = 1024 + + # --------------------------Config--------------------------- + # Skip this test if "-mcpu=sapphirerapids" not supported by LLVM < 12.0 + target = "llvm -mcpu=sapphirerapids" + dev = tvm.device(target, 0) + if not tvm.testing.device_enabled(target): + print("skip because %s is not enabled..." % target) + return + + amx_init = tvm.get_global_func("runtime.amx_init") + amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") + assert amx_init() + assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns. 
+ # --------------------------Compute-------------------------- + X = te.placeholder((m, k), name="X", dtype="uint8") + ak = te.reduce_axis((0, k), name="k") + packedW = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8") + + C = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("int32") + * packedW[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype( + "int32" + ), + axis=ak, + ), + name="F", + ) + + # --------------------------Schedule-------------------------- + s = te.create_schedule(C.op) + a_x, a_y = C.op.axis + (a_k,) = C.op.reduce_axis + + CF = s.cache_write(C, "amx.tmm") + a_xo, a_xi = s[C].split(a_x, factor=32) + a_yo, a_yi = s[C].split(a_y, factor=32) + s[C].reorder(a_xo, a_yo, a_xi, a_yi) + + s[CF].compute_at(s[C], a_yo) + (a_k_f,) = CF.op.reduce_axis + a_x_f, a_y_f = CF.op.axis + + a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32) + a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32) + a_ko_f, a_ki_f = s[CF].split(a_k_f, factor=128) + s[CF].reorder(a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f) + + s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=k)) + s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=n)) + + lib = tvm.build(s, [X, packedW, C], target, name="intrinsic") + asm = lib.get_source("asm") + assert "tilezero" in asm + assert "tileloaddt1" in asm + assert "tdpbusd" in asm + assert "tilestored" in asm + + # ----------------------- verify correctness -------------------------------- + # generate the plain data + a = np.random.uniform(1, 10, size=(m, k)).astype("uint8") + b = np.random.uniform(1, 10, size=(n, k)).astype("int8") + packW = np.random.uniform(1, 10, size=(n // 16, k // 4, 16, 4)).astype("int8") + + # This should occurs in pre_pack (constant folding) stage, + # from plain data to blocked data(NC16n4c) + for i_n in range(n): + for i_k in range(k): + packW[i_n // 16][i_k // 4][i_n % 16][i_k % 4] = b[i_n][i_k] + + x = tvm.nd.array(a, dev) + w = tvm.nd.array(packW, dev) + 
y = tvm.nd.array(np.zeros((m, n), dtype="int32"), dev) + t_evaluator = lib.time_evaluator(lib.entry_name, dev, number=100) + result = t_evaluator(x, w, y) + print(result) + tvm.testing.assert_allclose(y.numpy(), np.dot(a.astype("int32"), b.T.astype("int32")), rtol=0) + + +@tvm.testing.requires_llvm +@pytest.mark.skip("skip due to AMX feature not avaliable yet") +def test_amx_check_support(): + amx_init = tvm.get_global_func("runtime.amx_init") + amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") + assert amx_init() + assert amx_tileconfig(16, 64) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index bd4e1b72c3cd..9f31acfa6d7f 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -799,6 +799,53 @@ def test_dense_vnni(m, n, k): np.testing.assert_equal(out, ref) +@tvm.testing.requires_llvm +@pytest.mark.skip("skip due to AMX feature not avaliable yet") +def test_dense_amx_int8(): + data_shape = (32, 128) + weight_shape = (32, 128) + + amx_init = tvm.get_global_func("runtime.amx_init") + amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") + assert amx_init() + assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns. 
+ + for data_dtype in ["uint8", "int8"]: + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype="int8") + bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") + dense = relay.nn.dense(data, weight, out_dtype="int32") + out = relay.nn.bias_add(dense, bias) + mod = tvm.IRModule.from_expr(out) + + target = "llvm -mcpu=sapphirerapids" + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target) + + asm = lib.lib.get_source("asm") + assert "tilezero" in asm + assert "tileloaddt1" in asm + assert "tdpbusd" in asm + assert "tilestored" in asm + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + a = np.random.uniform(1, 10, size=data_shape).astype(data_dtype) + b = np.random.uniform(1, 10, size=weight_shape).astype("int8") + c = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") + + runtime.set_input("data", a) + runtime.set_input("weight", b) + runtime.set_input("bias", c) + runtime.run() + + out = runtime.get_output(0).numpy() + ref = np.dot(a.astype("int32"), b.transpose().astype("int32")) + c + + np.testing.assert_equal(out, ref) + + @pytest.mark.skip("Requires GFX10 AMDGPU") def test_dense_rocm_sdot4(): data_shape = (32, 96)