From 32308867667eeafc5c289ec228ce58e83da5e367 Mon Sep 17 00:00:00 2001
From: Qianshui
Date: Mon, 19 Dec 2022 01:01:42 -0500
Subject: [PATCH 01/21] add AMX config functions and build option.

---
 CMakeLists.txt                        |   2 +
 cmake/modules/LibInfo.cmake           |   1 +
 cmake/modules/contrib/AMX.cmake       |  23 +++
 src/runtime/contrib/amx/amx_config.cc | 143 ++++++++++++++++++++++++++
 src/support/libinfo.cc                |   5 +
 tests/python/contrib/test_amx.py      |  33 ++++++
 6 files changed, 207 insertions(+)
 create mode 100644 cmake/modules/contrib/AMX.cmake
 create mode 100644 src/runtime/contrib/amx/amx_config.cc
 create mode 100644 tests/python/contrib/test_amx.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 119bf8325c8c..18b85a582279 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -86,6 +86,7 @@ tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
 # Contrib library options
 tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF)
 tvm_option(USE_BLAS "The blas library to be linked" none)
+tvm_option(USE_AMX "Enable Intel AMX" OFF)
 tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
 tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
 tvm_option(USE_CUDNN "Build with cuDNN" OFF)
@@ -495,6 +496,7 @@ include(cmake/modules/contrib/EthosU.cmake)
 include(cmake/modules/contrib/BLAS.cmake)
 include(cmake/modules/contrib/CODEGENC.cmake)
 include(cmake/modules/contrib/DNNL.cmake)
+include(cmake/modules/contrib/AMX.cmake)
 include(cmake/modules/contrib/CUTLASS.cmake)
 include(cmake/modules/contrib/ExampleTargetHooks.cmake)
 include(cmake/modules/contrib/Random.cmake)
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index 7c24088c0ad2..4fc7ed3262d5 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -65,6 +65,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_CUDNN="${USE_CUDNN}"
     TVM_INFO_USE_CUSTOM_LOGGING="${USE_CUSTOM_LOGGING}"
     TVM_INFO_USE_CUTLASS="${USE_CUTLASS}"
+    TVM_INFO_USE_AMX="${USE_AMX}"
     TVM_INFO_USE_DNNL="${USE_DNNL}"
     TVM_INFO_USE_ETHOSN="${USE_ETHOSN}"
     TVM_INFO_USE_FALLBACK_STL_MAP="${USE_FALLBACK_STL_MAP}"
diff --git a/cmake/modules/contrib/AMX.cmake b/cmake/modules/contrib/AMX.cmake
new file mode 100644
index 000000000000..ac349c4336a2
--- /dev/null
+++ b/cmake/modules/contrib/AMX.cmake
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# amx_config.cc uses AMX tile intrinsics, so the host compiler must accept
+# -march=sapphirerapids (GCC >= 11 or LLVM >= 12).
+if(USE_AMX)
+  file(GLOB AMX_RUNTIME_CONFIG src/runtime/contrib/amx/amx_config.cc)
+  list(APPEND COMPILER_SRCS ${AMX_RUNTIME_CONFIG})
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids")
+  message(STATUS "Build with Intel AMX support...")
+endif()
diff --git a/src/runtime/contrib/amx/amx_config.cc b/src/runtime/contrib/amx/amx_config.cc
new file mode 100644
index 000000000000..f04126405a03
--- /dev/null
+++ b/src/runtime/contrib/amx/amx_config.cc
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/amx/amx_config.cc
+ * \brief Runtime functions to detect, enable, and configure Intel AMX on x86 platforms
+ */
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace runtime {
+
+#ifdef __linux__
+#include <errno.h>
+#include <immintrin.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#define XFEATURE_XTILECFG 17
+#define XFEATURE_XTILEDATA 18
+#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)
+#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)
+#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)
+#define ARCH_GET_XCOMP_PERM 0x1022
+#define ARCH_REQ_XCOMP_PERM 0x1023
+
+typedef struct __tile_config {
+  uint8_t palette_id;
+  uint8_t start_row;
+  uint8_t reserved_0[14];
+  uint16_t colsb[8]; /* Column size of each tmm register in bytes */
+  uint16_t reserved_1[8];
+  uint8_t rows[8]; /* Row count of each tmm register */
+  uint8_t reserved_2[8];
+} __tilecfg;
+
+typedef union __union_tile_config {
+  __tilecfg s;
+  uint8_t a[64];
+} __tilecfg_u;
+
+void init_tile_config(__tilecfg_u* dst, uint16_t cols, uint8_t rows) {
+  dst->s.palette_id = 1;
+  dst->s.start_row = 0;
+
+  for (int i = 0; i < 14; i++) dst->s.reserved_0[i] = 0;
+
+  for (int i = 0; i < 8; i++) {
+    dst->s.colsb[i] = cols;
+    dst->s.rows[i] = rows;
+    dst->s.reserved_1[i] = 0;
+    dst->s.reserved_2[i] = 0;
+  }
+
+  _tile_loadconfig(dst->a);
+}
+
+TVM_REGISTER_GLOBAL("runtime.amx_tileconfig").set_body([](TVMArgs args, TVMRetValue* rv) {
+  int rows = args[0];
+  int cols = args[1];
+  LOG(INFO) << "rows: " << rows << ", cols: " << cols;
+  // -----------Config for AMX tile register----------------------
+  __tilecfg_u cfg;
+  init_tile_config(&cfg, cols, rows);
+
+  *rv = 1;
+  return;
+});
+
+// Register a global packed function in C++ to initialize the system for AMX usage.
+TVM_REGISTER_GLOBAL("runtime.amx_init").set_body([](TVMArgs args, TVMRetValue* rv) {
+  // -----------Enlarge the signal stack in Linux----------------------
+  char largestack[15535];  // comfortably larger than SIGSTKSZ (8192) and MINSIGSTKSZ (2048)
+  stack_t ss;
+  ss.ss_sp = largestack;
+  ss.ss_size = sizeof(largestack);
+  ss.ss_flags = 0;
+  if (sigaltstack(&ss, NULL)) exit(-1);
+
+  // -----------Detect and request AMX control----------------------
+  uint64_t bitmask = 0;
+  int64_t status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);
+  if (0 != status) {
+    *rv = 0;
+    LOG(FATAL) << "errno: " << errno << ", " << strerror(errno);
+    LOG(FATAL) << "status[0]: " << status << ", bitmask: " << bitmask
+               << ", XFEATURE_XTILEDATA setup failed, the TMUL feature is not allowed.";
+    return;
+  }
+  if (bitmask & XFEATURE_MASK_XTILEDATA) {
+    // XTILEDATA permission has already been granted
+    *rv = 1;
+    return;
+  }
+
+  status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
+  // if the XFEATURE_XTILEDATA request failed, TMUL usage is not allowed
+  if (0 != status) {
+    *rv = 0;
+    LOG(FATAL) << "errno: " << errno << ", " << strerror(errno);
+    LOG(FATAL) << "status[1]: " << status << ", bitmask: " << bitmask
+               << ", XFEATURE_XTILEDATA setup failed, TMUL usage is not allowed.";
+    return;
+  }
+
+  status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);
+  // re-check the permission; if XFEATURE_XTILEDATA is still not set, TMUL can't be used
+  if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) {
+    *rv = 0;
+    LOG(FATAL) << "errno: " << errno << ", " << strerror(errno);
+    LOG(FATAL) << "status[2]: " << status << ", bitmask: " << bitmask
+               << ", XFEATURE_XTILEDATA setup failed, can't use TMUL.";
+    return;
+  }
+
+  // XFEATURE_XTILEDATA set successfully, TMUL usage is allowed
+  *rv = 1;
+  return;
+});
+
+#endif
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index c0fc9881b4f5..51601cb93d6b 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -147,6 +147,10 @@
 #define TVM_INFO_USE_MKL "NOT-FOUND"
 #endif
 
+#ifndef TVM_INFO_USE_AMX
+#define TVM_INFO_USE_AMX "NOT-FOUND"
+#endif
+
 #ifndef TVM_INFO_USE_DNNL
 #define TVM_INFO_USE_DNNL "NOT-FOUND"
 #endif
@@ -270,6 +274,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
       {"USE_CUDNN", TVM_INFO_USE_CUDNN},
       {"USE_CUSTOM_LOGGING", TVM_INFO_USE_CUSTOM_LOGGING},
       {"USE_CUTLASS", TVM_INFO_USE_CUTLASS},
+      {"USE_AMX", TVM_INFO_USE_AMX},
       {"USE_DNNL", TVM_INFO_USE_DNNL},
       {"USE_ETHOSN", TVM_INFO_USE_ETHOSN},
       {"USE_FALLBACK_STL_MAP", TVM_INFO_USE_FALLBACK_STL_MAP},
diff --git a/tests/python/contrib/test_amx.py b/tests/python/contrib/test_amx.py
new file mode 100644
index 000000000000..090f21343b8c
--- /dev/null
+++ b/tests/python/contrib/test_amx.py
@@ -0,0 +1,33 @@
+import pytest
+import itertools
+import numpy as np
+import os
+import sys
+import subprocess
+import math
+import collections
+
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.testing.temp_op_attr import TempOpAttr
+from tvm.relay.op.contrib import dnnl
+import tvm.testing
+
+
+def test_amx_tensorize(dtypt="int8"):
+    pass
+
+
+def test_amx_check_support():
+    amx_init = tvm.get_global_func("runtime.amx_init")
+    amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig")
+    if not amx_init():
+        print("[ ERROR ] AMX not inited !!!!!")
+    if not amx_tileconfig(16, 64):
+        print("[ ERROR ] AMX not tile configed !!!!!")
+
+
+if __name__ == "__main__":
+    print("[Test for TVM - LLVM - AMX intrinsic call pid: {}]".format(os.getpid()))
+    test_amx_check_support()

From 2312242ddb74bf546eb6f54b3c5b09c1cba5553b Mon Sep 17 00:00:00 2001
From: Qianshui
Date: Mon, 19 Dec 2022 03:02:58 -0500
Subject: [PATCH 02/21] amx tensor intrinsics and u8s8s32 matmul testcase

---
 python/tvm/topi/x86/tensor_intrin.py        | 141 +++++++++++++++++++
 python/tvm/topi/x86/utils.py                |   7 +
 src/runtime/thread_storage_scope.h          |   5 +
 tests/python/contrib/test_amx.py            | 129 ++++++++++++++---
tests/python/contrib/test_gemm_acc32_vnni.py | 2 +- 5 files changed, 262 insertions(+), 22 deletions(-) diff --git a/python/tvm/topi/x86/tensor_intrin.py b/python/tvm/topi/x86/tensor_intrin.py index 9e91e32b20e5..90857dbae160 100644 --- a/python/tvm/topi/x86/tensor_intrin.py +++ b/python/tvm/topi/x86/tensor_intrin.py @@ -348,3 +348,144 @@ def _instr(index): binds={data: a_buffer, kernel: b_buffer}, default_buffer_params=buffer_params, ) + +def dot_32x128x32_u8s8s32_sapphirerapids(LDA): + """ + Int8 dot product by every 16x64 elements using AMX-TMUL Sapphire Rapids instructions. + The tdpxxd instruction takes two tile of uint8 and int8 datatype -- data[16][64] and + kernel[1][16][16][4] -- and computes a dot product of data[16][16] in int32 datatype. + + (Physically, to efficiently leveraging the tile register, we constructing a 2x2 tiles + matmul which performs 32x128x32 in total) + + The pseudo code is as follows: + for(k=0; k<2; k++){ + for(n=0; n<2; n++){ + tileload64(tmm_b, B) + for(m=0; m<2; m++){ + if(n==0) + tileload64(tmm_a, A) + tdpbusd(tmm_c, tmm_a, tmm_b) + } + } + } + + Args: + LDA (int): the stride of the matrix A, which is uint8 type and + use it to determine memory strides of macro reduce axis. + + Returns + ------- + intrin : TensorIntrin + The Sapphire Rapids AMX-TMUL int8 tdpbusd TensorIntrin that can be used in tensorizing schedule + """ + A = te.placeholder((32, 128), name="A", dtype="uint8") + B = te.placeholder((2, 32, 16, 4), name="B", dtype="int8") + k = te.reduce_axis((0, 128), name="k") + + C = te.compute( + (32, 32), + lambda i, j: te.sum( + A[i, k].astype("int32") + * B[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(k, 4), j % 16, k % 4].astype("int32"), + axis=k, + ), + name="C", + ) + + BA = tvm.tir.decl_buffer(A.shape, A.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BA") + BB = tvm.tir.decl_buffer(B.shape, B.dtype, offset_factor=1, strides=[te.var("ldw"), te.var("ldw"), te.var("ldw"), 1], name="BB") + BC = tvm.tir.decl_buffer(C.shape, C.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BC",scope="amx.tmm") + + def intrin_func(ins, outs): + bufA = ins[0] + bufB = ins[1] + bufC = outs[0] + + assert LDA + _strides_A = tvm.tir.const(LDA, dtype="uint64") + _strides_B_tile = tvm.tir.const(LDA/128, dtype="uint64") + + def init(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(0, dtype="uint8") )) # tile C 0 + ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(1, dtype="uint8") )) # tile C 1 + ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(2, dtype="uint8") )) # tile C 2 + ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(3, dtype="uint8") )) # tile C 3 + + return ib.get() + + def body(): # load A, load B, dpbusd, store C + ib = tvm.tir.ir_builder.create() + + for k_tile in range(2): # reduced data blocks + for n_acc in range(2): # broadcast data blocks + tmm_B_ = tvm.tir.const(n_acc+6, dtype="uint8") + ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tileloaddt164", # load B: tmm6, tmm7 + tvm.tir.const(3, "uint8"), + tmm_B_, bufB.access_ptr("r", offset=64*16*(n_acc*2*_strides_B_tile + k_tile)), tvm.tir.const(64, dtype="uint64"))) + + for m_acc in range(2): # loaded data blocks + tmm_A_ = tvm.tir.const(m_acc+4, dtype="uint8") + if n_acc == 0: + 
ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tileloaddt164", # load A: , tmm4, tmm5 + tvm.tir.const(3, "uint8"), + tmm_A_, bufA.access_ptr("r", offset=m_acc*16*_strides_A + k_tile*64), _strides_A)) + + tmm_C_ = tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8") + ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tdpbusd", + tvm.tir.const(3, "uint8"), + tmm_C_, tmm_A_, tmm_B_)) # tdpxxd + + return ib.get() + + # body, reset, store + return body(), init(), body(), + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + +def acc_32x32_int32_sapphirerapids(LDC): + """ + Store the accumulated tile register in scope amx.tmm to global memory. + (tmm0, tmm1, tmm2, tmm3 --> global 4 tiles) + + Args: + LDC (int): the stride of the matrix C, which is int32 type and + use it to determine memory strides. + + Returns + ------- + intrin : TensorIntrin + The Sapphirerapids AMX-TMUL int8 tilestored64 TensorIntrin that can be used in tensorizing schedule + """ + A = te.placeholder((32, 32), name="A", dtype="int32") + bufA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="amx.tmm", name="a_buffer", offset_factor=1, strides=[te.var("ldw"), 1]) + + C = te.compute((32, 32), lambda i, j: A[i, j], name="C") + bufC = tvm.tir.decl_buffer(C.shape, C.dtype, scope="global", name="c_buffer", offset_factor=1, strides=[te.var("ldw"), 1]) + + assert LDC + _strides_C = tvm.tir.const(4*LDC, dtype="uint64") + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + bufA = ins[0] + bufC = outs[0] + for n_acc in range(2): # broadcast data blocks + for m_acc in range(2): # loaded data blocks + ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilestored64", + tvm.tir.const(3, "uint8"), + tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8"), + bufC.access_ptr( "w", offset=n_acc*16 + m_acc*16*_strides_C/4), _strides_C)) + + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: bufA, C: bufC}) \ No newline at end of file diff --git a/python/tvm/topi/x86/utils.py b/python/tvm/topi/x86/utils.py index c364027022da..efe5913269a1 100644 --- a/python/tvm/topi/x86/utils.py +++ b/python/tvm/topi/x86/utils.py @@ -123,6 +123,13 @@ def target_has_vnni(target): } +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_amx") +def target_has_amx(target): + return target in { + "sapphirerapids", + } + + @tvm._ffi.register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") def get_simd_32bit_lanes(): mcpu = tvm.target.Target.current().mcpu diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 83477312dcc5..51dba038b6ac 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -62,6 +62,8 @@ enum class StorageRank { kWMMAAccumulator = 6, /*! \brief global scope texture memory */ kTexture = 7, + /*! \brief global scope amx tmm memory */ + kAMXTMM = 8, }; /*! 
@@ -149,6 +151,9 @@ struct StorageScope { } else if (s.compare(0, 7, "texture") == 0) { r.rank = StorageRank::kTexture; r.tag = s.substr(7, std::string::npos); + } else if (s.compare(0, 7, "amx.tmm") == 0) { + r.rank = StorageRank::kAMXTMM; + r.tag = s.substr(7, std::string::npos); } else { LOG(FATAL) << "unknown storage scope " << s; } diff --git a/tests/python/contrib/test_amx.py b/tests/python/contrib/test_amx.py index 090f21343b8c..a49b7f30edb6 100644 --- a/tests/python/contrib/test_amx.py +++ b/tests/python/contrib/test_amx.py @@ -1,33 +1,120 @@ -import pytest -import itertools -import numpy as np -import os -import sys -import subprocess -import math -import collections +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition import tvm -from tvm import relay -from tvm.relay import transform -from tvm.relay.build_module import bind_params_by_name -from tvm.relay.testing.temp_op_attr import TempOpAttr -from tvm.relay.op.contrib import dnnl +from tvm import te import tvm.testing +from tvm.topi.x86.tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids +from tvm.topi.x86.tensor_intrin import acc_32x32_int32_sapphirerapids +import numpy as np +import pytest + +@tvm.testing.requires_llvm +@pytest.mark.skip("skip because AMX feature not avaliable yet") +def test_amx_u8s8s32_matmul_tensorize(): + m = 1024 + k = 1024 + n = 1024 + + # --------------------------Config--------------------------- + # Skip this test if "-mcpu=sapphirerapids" not supported by LLVM < 12.0 + target="llvm -mcpu=sapphirerapids" + dev = tvm.device(target, 0) + if not tvm.testing.device_enabled(target): + print("skip because %s is not enabled..." % target) + return + + amx_init = tvm.get_global_func("runtime.amx_init") + amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") + assert amx_init() + assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns. 
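+    # amx_tileconfig programs every tmm register with the same shape. A tile holds
+    # at most 16 rows x 64 bytes, so (16, 64) requests the largest geometry: one
+    # tile spans 16x64 uint8/int8 elements, or 16x16 int32 accumulator values,
+    # which is the layout the tensorize intrinsics below assume.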
+ # --------------------------Compute-------------------------- + X = te.placeholder((m, k), name="X", dtype="uint8") + ak = te.reduce_axis((0, k), name="k") + packedW = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8") + + C = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("int32") + * packedW[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype("int32"), + axis=ak, + ), + name="F", + ) + # --------------------------Schedule-------------------------- + s = te.create_schedule(C.op) + a_x, a_y = C.op.axis + (a_k,) = C.op.reduce_axis -def test_amx_tensorize(dtypt="int8"): - pass + CF = s.cache_write(C, "amx.tmm") + a_xo, a_xi = s[C].split(a_x, factor=32) + a_yo, a_yi = s[C].split(a_y, factor=32) + s[C].reorder(a_xo, a_yo, a_xi, a_yi) + + s[CF].compute_at(s[C], a_yo) + (a_k_f,) = CF.op.reduce_axis + a_x_f, a_y_f = CF.op.axis + + a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32) + a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32) + a_ko_f, a_ki_f = s[CF].split(a_k_f, factor=128) + s[CF].reorder(a_ko_f, a_xo_f, a_yo_f, + a_ki_f, a_xi_f, a_yi_f) + + s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=k)) + s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=n)) + + lib = tvm.build(s, [X, packedW, C], target, name="intrinsic") + asm = lib.get_source("asm") + assert "tilezero" in asm + assert "tileloaddt1" in asm + assert "tdpbusd" in asm + assert "tilestored" in asm + + # ----------------------- verify correctness -------------------------------- + # generate the plain data + a = np.random.uniform(1, 10, size=(m, k)).astype("uint8") + b = np.random.uniform(1, 10, size=(n, k)).astype("int8") + packW = np.random.uniform(1, 10, size=(n // 16, k // 4, 16, 4)).astype("int8") + + # This should occurs in pre_pack (constant folding) stage, + # from plain data to blocked data(NC16n4c) + for i_n in range(n): + for i_k in range(k): + packW[i_n//16][i_k//4][i_n%16][i_k%4] = b[i_n][i_k] + + x = tvm.nd.array(a, dev) + w = tvm.nd.array(packW, dev) + y = tvm.nd.array(np.zeros((m, n), dtype="int32"), dev) + t_evaluator = lib.time_evaluator(lib.entry_name, dev, number=100) + result = t_evaluator(x, w, y) + print(result) + tvm.testing.assert_allclose(y.numpy(), np.dot(a.astype("int32"), b.T.astype("int32")), rtol=0) + print("[TEST PASS ! ! ! ]") def test_amx_check_support(): amx_init = tvm.get_global_func("runtime.amx_init") amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") - if not amx_init(): - print("[ ERROR ] AMX not inited !!!!!") - if not amx_tileconfig(16, 64): - print("[ ERROR ] AMX not tile configed !!!!!") - + assert amx_init() + assert amx_tileconfig(16, 64) if __name__ == "__main__": - print("[Test for TVM - LLVM - AMX intrinsic call pid: {}]".format(os.getpid()) ) test_amx_check_support() + test_amx_u8s8s32_matmul_tensorize() \ No newline at end of file diff --git a/tests/python/contrib/test_gemm_acc32_vnni.py b/tests/python/contrib/test_gemm_acc32_vnni.py index 9cec823cc58a..9121da21476a 100644 --- a/tests/python/contrib/test_gemm_acc32_vnni.py +++ b/tests/python/contrib/test_gemm_acc32_vnni.py @@ -115,5 +115,5 @@ def verify(target="llvm -mcpu=cascadelake"): # The test requires Cascade Lake and newer Intel machines to generate the # correct AVX512 VNNI instruction. So, disabling the test. 
- # test_fc_int8_acc32() + # test_fc_int8_acc32()`` pass From 3e2fc4e7ac77123938397b263e73b8876a2e9a8e Mon Sep 17 00:00:00 2001 From: Qianshui Date: Mon, 19 Dec 2022 04:24:37 -0500 Subject: [PATCH 03/21] add int8 dense kernel use amx tensorize --- python/tvm/relay/op/strategy/x86.py | 20 ++- python/tvm/topi/x86/dense.py | 150 ++++++++++++++++++++ python/tvm/topi/x86/dense_alter_op.py | 10 +- python/tvm/topi/x86/tensor_intrin.py | 195 ++++++++++++++++++-------- tests/python/contrib/test_amx.py | 34 +++-- tests/python/relay/test_op_level1.py | 42 ++++++ 6 files changed, 374 insertions(+), 77 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 7ff4dbc0ad1b..ba310cae21cd 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -25,7 +25,7 @@ from tvm.relay.ty import is_dynamic from tvm.target import Target from tvm.te import SpecializedCondition -from tvm.topi.x86.utils import target_has_vnni +from tvm.topi.x86.utils import target_has_vnni, target_has_amx from .. import op as _op from .generic import * @@ -591,9 +591,23 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def dense_pack_strategy_cpu(attrs, inputs, out_type, target): """dense_pack x86 strategy""" strategy = _op.OpStrategy() - + mcpu = Target.current().mcpu if ( - inputs[0].dtype == "uint8" + target_has_amx(mcpu) + and inputs[0].dtype == "uint8" + and inputs[1].dtype == "int8" + and out_type.dtype == "int32" + and attrs["weight_layout"] == "NC16n4c" + ): + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_amx_int8), + wrap_topi_schedule(topi.x86.schedule_dense_amx_int8), + name="dense_amx_int8.x86", + plevel=13, + ) + elif ( + target_has_vnni(mcpu) + and inputs[0].dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32" and attrs["weight_layout"] == "NC16n4c" diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 65a803781a57..d7e971c116e7 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -27,6 +27,8 @@ from .. 
import generic, tag from ..utils import get_const_tuple, traverse_inline from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake +from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids +from .tensor_intrin import acc_32x32_int32_sapphirerapids from .utils import get_simd_32bit_lanes @@ -373,6 +375,154 @@ def _callback(op): return s +def dense_amx_int8_compute(cfg, data, packed_w, bias=None): + """Compute for uint8 x int8 -> int32 dense""" + m, k = data.shape + n_o, _, n_i, _ = packed_w.shape + ak = te.reduce_axis((0, k), name="k") + + C = te.compute( + (m, n_o * n_i), + lambda i, j: te.sum( + data[i, ak].astype("int32") + * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype( + "int32" + ), + axis=ak, + ), + tag="dense_amx_int8", + attrs={"schedule_rule": "meta_schedule.dense_amx_int8"}, + ) + + if bias is not None: + C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) + + return C + + +def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True): + """Schedule dense compute using AMX TMUL instruction""" + # C: The output of GEMM + # O: The output of the fused op + def split_x(out): + default_x_split_factor1 = 32 + default_x_split_factor2 = 2 + default_x_split_factor3 = 2 + default_x_split_factor4 = 2 + a_x = s[out].op.axis[-2] + + if cfg.is_fallback: + a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1) + a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2) + a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3) + a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4) + return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi] + + cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x: x.size[-1] == 32) + return cfg["tile_x"].apply(s, out, a_x) + + def split_y(out): + default_y_split_factor1 = 32 + default_y_split_factor2 = 4 + default_y_split_factor3 = 4 + default_y_split_factor4 = 4 + a_y = s[out].op.axis[-1] + + if cfg.is_fallback: + a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1) + a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2) + a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3) + a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4) + return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo] + + cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y: y.size[-1] == 32) + return cfg["tile_y"].apply(s, out, a_y) + + def split_k(out, rd_axis): + default_k_split_factor1 = 128 + default_k_split_factor2 = 2 + default_k_split_factor3 = 2 + default_k_split_factor4 = 2 + + if cfg.is_fallback: + a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1) + a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2) + a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3) + a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4) + return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki] + + cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128) + return cfg["tile_k"].apply(s, out, rd_axis) + + a_x, a_y = C.op.axis + (a_k,) = C.op.reduce_axis + CF = s.cache_write(C, "amx.tmm") + + a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C) + a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C) + s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi) + + s[CF].compute_at(s[C], a_yo) + + (a_k_f,) = CF.op.reduce_axis + a_x_f, a_y_f = CF.op.axis + + a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32) + + a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32) + a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = 
split_k(CF, a_k_f) + s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f) + + (m, k) = CF.op.input_tensors[0].shape + (n, c, n_i, c_i) = CF.op.input_tensors[1].shape + n = n * n_i + + s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k))) + s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n))) + + if C == O: + fused = s[O].fuse(a_x3, a_y3) + else: + a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O) + a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O) + + s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, a_xi) + s[O].vectorize(a_xi) + s[C].compute_at(s[O], a_y2) + + fused = s[O].fuse(a_y3, a_x3) + + if do_parallel: + s[O].parallel(fused) + + return s, fused + + +@autotvm.register_topi_compute("dense_amx_int8.x86") +def dense_amx_int8(cfg, data, weight, bias=None, out_dtype=None): + """Compute for uint8 x int8 -> int32 dense""" + if out_dtype is None: + out_dtype = data.dtype + assert len(weight.shape) == 4 + assert data.dtype == "uint8" and weight.dtype == "int8" + _, _, _, n_inner = get_const_tuple(weight.shape) # out_dim + assert n_inner == 4 + return dense_amx_int8_compute(cfg, data, weight, bias) + + +@autotvm.register_topi_schedule("dense_amx_int8.x86") +def schedule_dense_amx_int8(cfg, outs): + """Create a schedule for dense_amx_int8""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "dense_amx_int8" in op.tag: + dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) + + traverse_inline(s, outs[0].op, _callback) + return s + + def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib): """Compute matmul/dense using a BLAS library""" M, K = get_const_tuple(tensor_a.shape) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index fd2b184a87d2..9a2fd4b44ed7 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -25,13 +25,15 @@ from ..utils import get_const_tuple from ..nn import dense_alter_layout from .utils import target_has_vnni +from .utils import target_has_amx from .. 
import nn -def check_vnni_applicable(x, y, allow_padding=False): +def check_inst_applicable(x, y, allow_padding=False): mcpu = tvm.target.Target.current().mcpu + cpu_avai = target_has_vnni(mcpu) or target_has_amx(mcpu) return ( - target_has_vnni(mcpu) + cpu_avai and "int8" in x.dtype and "int8" in y.dtype and (allow_padding or (y.shape[-2] % 16 == 0 and y.shape[-1] % 4 == 0)) @@ -47,7 +49,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): M, K = get_const_tuple(data_tensor.shape) N, _ = get_const_tuple(weight_tensor.shape) - if check_vnni_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8": + if check_inst_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8": weight_layout = "NC16n4c" return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype) @@ -87,7 +89,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False): """Legalizes s8, s8 -> s32 GEMM op for VNNI.""" if ( - check_vnni_applicable(arg_types[0], arg_types[1], allow_padding=True) + check_inst_applicable(arg_types[0], arg_types[1], allow_padding=True) and arg_types[0].dtype == "int8" ): x, y = inputs diff --git a/python/tvm/topi/x86/tensor_intrin.py b/python/tvm/topi/x86/tensor_intrin.py index 90857dbae160..ffa32d74c467 100644 --- a/python/tvm/topi/x86/tensor_intrin.py +++ b/python/tvm/topi/x86/tensor_intrin.py @@ -349,13 +349,14 @@ def _instr(index): default_buffer_params=buffer_params, ) + def dot_32x128x32_u8s8s32_sapphirerapids(LDA): - """ + """ Int8 dot product by every 16x64 elements using AMX-TMUL Sapphire Rapids instructions. The tdpxxd instruction takes two tile of uint8 and int8 datatype -- data[16][64] and kernel[1][16][16][4] -- and computes a dot product of data[16][16] in int32 datatype. - (Physically, to efficiently leveraging the tile register, we constructing a 2x2 tiles + (Physically, to efficiently leveraging the tile register, we constructing a 2x2 tiles matmul which performs 32x128x32 in total) The pseudo code is as follows: @@ -371,13 +372,14 @@ def dot_32x128x32_u8s8s32_sapphirerapids(LDA): } Args: - LDA (int): the stride of the matrix A, which is uint8 type and - use it to determine memory strides of macro reduce axis. + LDA (int): the stride of the matrix A, which is uint8 type and use it to determine + memory strides of macro reduce axis. 
Returns ------- intrin : TensorIntrin - The Sapphire Rapids AMX-TMUL int8 tdpbusd TensorIntrin that can be used in tensorizing schedule + The Sapphire Rapids AMX-TMUL int8 tdpbusd TensorIntrin that can be used in tensorizing + schedule """ A = te.placeholder((32, 128), name="A", dtype="uint8") B = te.placeholder((2, 32, 16, 4), name="B", dtype="int8") @@ -393,99 +395,180 @@ def dot_32x128x32_u8s8s32_sapphirerapids(LDA): name="C", ) - BA = tvm.tir.decl_buffer(A.shape, A.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BA") - BB = tvm.tir.decl_buffer(B.shape, B.dtype, offset_factor=1, strides=[te.var("ldw"), te.var("ldw"), te.var("ldw"), 1], name="BB") - BC = tvm.tir.decl_buffer(C.shape, C.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BC",scope="amx.tmm") + BA = tvm.tir.decl_buffer( + A.shape, A.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BA" + ) + BB = tvm.tir.decl_buffer( + B.shape, + B.dtype, + offset_factor=1, + strides=[te.var("ldw"), te.var("ldw"), te.var("ldw"), 1], + name="BB", + ) + BC = tvm.tir.decl_buffer( + C.shape, C.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BC", scope="amx.tmm" + ) def intrin_func(ins, outs): - bufA = ins[0] + bufA = ins[0] bufB = ins[1] bufC = outs[0] assert LDA _strides_A = tvm.tir.const(LDA, dtype="uint64") - _strides_B_tile = tvm.tir.const(LDA/128, dtype="uint64") - + _strides_B_tile = tvm.tir.const(LDA / 128, dtype="uint64") + def init(): - ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", - tvm.tir.const(1, "uint8"), - tvm.tir.const(0, dtype="uint8") )) # tile C 0 - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", - tvm.tir.const(1, "uint8"), - tvm.tir.const(1, dtype="uint8") )) # tile C 1 - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", - tvm.tir.const(1, "uint8"), - tvm.tir.const(2, dtype="uint8") )) # tile C 2 - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", - tvm.tir.const(1, "uint8"), - tvm.tir.const(3, dtype="uint8") )) # tile C 3 + ib = tvm.tir.ir_builder.create() + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(0, dtype="uint8"), + ) + ) # tile C 0 + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(1, dtype="uint8"), + ) + ) # tile C 1 + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(2, dtype="uint8"), + ) + ) # tile C 2 + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(3, dtype="uint8"), + ) + ) # tile C 3 return ib.get() - def body(): # load A, load B, dpbusd, store C + def body(): # load A, load B, dpbusd, store C ib = tvm.tir.ir_builder.create() - - for k_tile in range(2): # reduced data blocks - for n_acc in range(2): # broadcast data blocks - tmm_B_ = tvm.tir.const(n_acc+6, dtype="uint8") - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tileloaddt164", # load B: tmm6, tmm7 - tvm.tir.const(3, "uint8"), - tmm_B_, bufB.access_ptr("r", offset=64*16*(n_acc*2*_strides_B_tile + k_tile)), tvm.tir.const(64, dtype="uint64"))) - - for m_acc in range(2): # loaded data blocks - tmm_A_ = tvm.tir.const(m_acc+4, dtype="uint8") + + for k_tile in range(2): # reduced data blocks + for n_acc in range(2): # broadcast data blocks + tmm_B_ = tvm.tir.const(n_acc + 6, dtype="uint8") + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + 
"llvm.x86.tileloaddt164", # load B: tmm6, tmm7 + tvm.tir.const(3, "uint8"), + tmm_B_, + bufB.access_ptr( + "r", offset=64 * 16 * (n_acc * 2 * _strides_B_tile + k_tile) + ), + tvm.tir.const(64, dtype="uint64"), + ) + ) + + for m_acc in range(2): # loaded data blocks + tmm_A_ = tvm.tir.const(m_acc + 4, dtype="uint8") if n_acc == 0: - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tileloaddt164", # load A: , tmm4, tmm5 - tvm.tir.const(3, "uint8"), - tmm_A_, bufA.access_ptr("r", offset=m_acc*16*_strides_A + k_tile*64), _strides_A)) + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tileloaddt164", # load A: , tmm4, tmm5 + tvm.tir.const(3, "uint8"), + tmm_A_, + bufA.access_ptr( + "r", offset=m_acc * 16 * _strides_A + k_tile * 64 + ), + _strides_A, + ) + ) tmm_C_ = tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8") - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tdpbusd", - tvm.tir.const(3, "uint8"), - tmm_C_, tmm_A_, tmm_B_)) # tdpxxd + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tdpbusd", + tvm.tir.const(3, "uint8"), + tmm_C_, + tmm_A_, + tmm_B_, + ) + ) # tdpxxd return ib.get() # body, reset, store - return body(), init(), body(), + return ( + body(), + init(), + body(), + ) return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + def acc_32x32_int32_sapphirerapids(LDC): - """ + """ Store the accumulated tile register in scope amx.tmm to global memory. (tmm0, tmm1, tmm2, tmm3 --> global 4 tiles) Args: - LDC (int): the stride of the matrix C, which is int32 type and - use it to determine memory strides. + LDC (int): the stride of the matrix C, which is int32 type and use it to + determine memory strides. Returns ------- intrin : TensorIntrin - The Sapphirerapids AMX-TMUL int8 tilestored64 TensorIntrin that can be used in tensorizing schedule + The Sapphirerapids AMX-TMUL int8 tilestored64 TensorIntrin that can be used + in tensorizing schedule """ A = te.placeholder((32, 32), name="A", dtype="int32") - bufA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="amx.tmm", name="a_buffer", offset_factor=1, strides=[te.var("ldw"), 1]) + bufA = tvm.tir.decl_buffer( + A.shape, + A.dtype, + scope="amx.tmm", + name="a_buffer", + offset_factor=1, + strides=[te.var("ldw"), 1], + ) C = te.compute((32, 32), lambda i, j: A[i, j], name="C") - bufC = tvm.tir.decl_buffer(C.shape, C.dtype, scope="global", name="c_buffer", offset_factor=1, strides=[te.var("ldw"), 1]) + bufC = tvm.tir.decl_buffer( + C.shape, + C.dtype, + scope="global", + name="c_buffer", + offset_factor=1, + strides=[te.var("ldw"), 1], + ) assert LDC - _strides_C = tvm.tir.const(4*LDC, dtype="uint64") + _strides_C = tvm.tir.const(4 * LDC, dtype="uint64") def intrin_func(ins, outs): ib = tvm.tir.ir_builder.create() bufA = ins[0] bufC = outs[0] - for n_acc in range(2): # broadcast data blocks - for m_acc in range(2): # loaded data blocks - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilestored64", - tvm.tir.const(3, "uint8"), - tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8"), - bufC.access_ptr( "w", offset=n_acc*16 + m_acc*16*_strides_C/4), _strides_C)) + for n_acc in range(2): # broadcast data blocks + for m_acc in range(2): # loaded data blocks + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilestored64", + tvm.tir.const(3, "uint8"), + tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8"), + bufC.access_ptr("w", offset=n_acc * 16 + m_acc * 16 * _strides_C / 4), + _strides_C, + ) + ) return ib.get() - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: bufA, C: bufC}) \ No 
newline at end of file + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: bufA, C: bufC}) diff --git a/tests/python/contrib/test_amx.py b/tests/python/contrib/test_amx.py index a49b7f30edb6..30da7e56fb8d 100644 --- a/tests/python/contrib/test_amx.py +++ b/tests/python/contrib/test_amx.py @@ -17,6 +17,8 @@ # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition import tvm +from tvm import relay + from tvm import te import tvm.testing from tvm.topi.x86.tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids @@ -24,25 +26,26 @@ import numpy as np import pytest + @tvm.testing.requires_llvm -@pytest.mark.skip("skip because AMX feature not avaliable yet") +@pytest.mark.skip("skip due to AMX feature not avaliable yet") def test_amx_u8s8s32_matmul_tensorize(): m = 1024 k = 1024 n = 1024 - + # --------------------------Config--------------------------- # Skip this test if "-mcpu=sapphirerapids" not supported by LLVM < 12.0 - target="llvm -mcpu=sapphirerapids" + target = "llvm -mcpu=sapphirerapids" dev = tvm.device(target, 0) if not tvm.testing.device_enabled(target): print("skip because %s is not enabled..." % target) return - + amx_init = tvm.get_global_func("runtime.amx_init") amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") assert amx_init() - assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns. + assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns. # --------------------------Compute-------------------------- X = te.placeholder((m, k), name="X", dtype="uint8") ak = te.reduce_axis((0, k), name="k") @@ -52,7 +55,9 @@ def test_amx_u8s8s32_matmul_tensorize(): (m, n), lambda i, j: te.sum( X[i, ak].astype("int32") - * packedW[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype("int32"), + * packedW[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype( + "int32" + ), axis=ak, ), name="F", @@ -69,14 +74,13 @@ def test_amx_u8s8s32_matmul_tensorize(): s[C].reorder(a_xo, a_yo, a_xi, a_yi) s[CF].compute_at(s[C], a_yo) - (a_k_f,) = CF.op.reduce_axis + (a_k_f,) = CF.op.reduce_axis a_x_f, a_y_f = CF.op.axis a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32) a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32) a_ko_f, a_ki_f = s[CF].split(a_k_f, factor=128) - s[CF].reorder(a_ko_f, a_xo_f, a_yo_f, - a_ki_f, a_xi_f, a_yi_f) + s[CF].reorder(a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f) s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=k)) s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=n)) @@ -87,7 +91,7 @@ def test_amx_u8s8s32_matmul_tensorize(): assert "tileloaddt1" in asm assert "tdpbusd" in asm assert "tilestored" in asm - + # ----------------------- verify correctness -------------------------------- # generate the plain data a = np.random.uniform(1, 10, size=(m, k)).astype("uint8") @@ -98,7 +102,7 @@ def test_amx_u8s8s32_matmul_tensorize(): # from plain data to blocked data(NC16n4c) for i_n in range(n): for i_k in range(k): - packW[i_n//16][i_k//4][i_n%16][i_k%4] = b[i_n][i_k] + packW[i_n // 16][i_k // 4][i_n % 16][i_k % 4] = b[i_n][i_k] x = tvm.nd.array(a, dev) w = tvm.nd.array(packW, dev) @@ -107,14 +111,16 @@ def test_amx_u8s8s32_matmul_tensorize(): result = t_evaluator(x, w, y) print(result) tvm.testing.assert_allclose(y.numpy(), np.dot(a.astype("int32"), b.T.astype("int32")), rtol=0) - print("[TEST PASS ! ! ! 
]") + +@tvm.testing.requires_llvm +@pytest.mark.skip("skip due to AMX feature not avaliable yet") def test_amx_check_support(): amx_init = tvm.get_global_func("runtime.amx_init") amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") assert amx_init() assert amx_tileconfig(16, 64) + if __name__ == "__main__": - test_amx_check_support() - test_amx_u8s8s32_matmul_tensorize() \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index bd4e1b72c3cd..1d3a88d43edb 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -799,6 +799,48 @@ def test_dense_vnni(m, n, k): np.testing.assert_equal(out, ref) +@tvm.testing.requires_llvm +@pytest.mark.skip("skip due to AMX feature not avaliable yet") +def test_dense_amx_int8(): + data_shape = (32, 128) + weight_shape = (32, 128) + + for data_dtype in ["uint8", "int8"]: + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype="int8") + bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") + dense = relay.nn.dense(data, weight, out_dtype="int32") + out = relay.nn.bias_add(dense, bias) + mod = tvm.IRModule.from_expr(out) + + target = "llvm -mcpu=sapphirerapids" + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target) + + asm = lib.lib.get_source("asm") + assert "tilezero" in asm + assert "tileloaddt1" in asm + assert "tdpbusd" in asm + assert "tilestored" in asm + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + a = np.random.uniform(1, 10, size=data_shape).astype(data_dtype) + b = np.random.uniform(1, 10, size=weight_shape).astype("int8") + c = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") + + runtime.set_input("data", a) + runtime.set_input("weight", b) + runtime.set_input("bias", c) + runtime.run() + + out = runtime.get_output(0).numpy() + ref = np.dot(a.astype("int32"), b.transpose().astype("int32")) + c + + np.testing.assert_equal(out, ref) + + @pytest.mark.skip("Requires GFX10 AMDGPU") def test_dense_rocm_sdot4(): data_shape = (32, 96) From 3f19099f9132ba32d463827b04451801952af953 Mon Sep 17 00:00:00 2001 From: Qianshui Date: Mon, 19 Dec 2022 04:32:34 -0500 Subject: [PATCH 04/21] add int8 dense kernel use amx tensorize --- python/tvm/relay/op/strategy/x86.py | 20 ++- python/tvm/topi/x86/dense.py | 150 ++++++++++++++++++++ python/tvm/topi/x86/dense_alter_op.py | 10 +- python/tvm/topi/x86/tensor_intrin.py | 195 ++++++++++++++++++-------- tests/python/contrib/test_amx.py | 34 +++-- tests/python/relay/test_op_level1.py | 42 ++++++ 6 files changed, 374 insertions(+), 77 deletions(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 7ff4dbc0ad1b..ba310cae21cd 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -25,7 +25,7 @@ from tvm.relay.ty import is_dynamic from tvm.target import Target from tvm.te import SpecializedCondition -from tvm.topi.x86.utils import target_has_vnni +from tvm.topi.x86.utils import target_has_vnni, target_has_amx from .. 
import op as _op from .generic import * @@ -591,9 +591,23 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def dense_pack_strategy_cpu(attrs, inputs, out_type, target): """dense_pack x86 strategy""" strategy = _op.OpStrategy() - + mcpu = Target.current().mcpu if ( - inputs[0].dtype == "uint8" + target_has_amx(mcpu) + and inputs[0].dtype == "uint8" + and inputs[1].dtype == "int8" + and out_type.dtype == "int32" + and attrs["weight_layout"] == "NC16n4c" + ): + strategy.add_implementation( + wrap_compute_dense(topi.x86.dense_amx_int8), + wrap_topi_schedule(topi.x86.schedule_dense_amx_int8), + name="dense_amx_int8.x86", + plevel=13, + ) + elif ( + target_has_vnni(mcpu) + and inputs[0].dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32" and attrs["weight_layout"] == "NC16n4c" diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 65a803781a57..d7e971c116e7 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -27,6 +27,8 @@ from .. import generic, tag from ..utils import get_const_tuple, traverse_inline from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake +from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids +from .tensor_intrin import acc_32x32_int32_sapphirerapids from .utils import get_simd_32bit_lanes @@ -373,6 +375,154 @@ def _callback(op): return s +def dense_amx_int8_compute(cfg, data, packed_w, bias=None): + """Compute for uint8 x int8 -> int32 dense""" + m, k = data.shape + n_o, _, n_i, _ = packed_w.shape + ak = te.reduce_axis((0, k), name="k") + + C = te.compute( + (m, n_o * n_i), + lambda i, j: te.sum( + data[i, ak].astype("int32") + * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype( + "int32" + ), + axis=ak, + ), + tag="dense_amx_int8", + attrs={"schedule_rule": "meta_schedule.dense_amx_int8"}, + ) + + if bias is not None: + C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) + + return C + + +def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True): + """Schedule dense compute using AMX TMUL instruction""" + # C: The output of GEMM + # O: The output of the fused op + def split_x(out): + default_x_split_factor1 = 32 + default_x_split_factor2 = 2 + default_x_split_factor3 = 2 + default_x_split_factor4 = 2 + a_x = s[out].op.axis[-2] + + if cfg.is_fallback: + a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1) + a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2) + a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3) + a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4) + return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi] + + cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x: x.size[-1] == 32) + return cfg["tile_x"].apply(s, out, a_x) + + def split_y(out): + default_y_split_factor1 = 32 + default_y_split_factor2 = 4 + default_y_split_factor3 = 4 + default_y_split_factor4 = 4 + a_y = s[out].op.axis[-1] + + if cfg.is_fallback: + a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1) + a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2) + a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3) + a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4) + return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo] + + cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y: y.size[-1] == 32) + return cfg["tile_y"].apply(s, out, a_y) + + def split_k(out, rd_axis): + default_k_split_factor1 = 128 + default_k_split_factor2 = 2 + 
default_k_split_factor3 = 2 + default_k_split_factor4 = 2 + + if cfg.is_fallback: + a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1) + a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2) + a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3) + a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4) + return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki] + + cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128) + return cfg["tile_k"].apply(s, out, rd_axis) + + a_x, a_y = C.op.axis + (a_k,) = C.op.reduce_axis + CF = s.cache_write(C, "amx.tmm") + + a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C) + a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C) + s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi) + + s[CF].compute_at(s[C], a_yo) + + (a_k_f,) = CF.op.reduce_axis + a_x_f, a_y_f = CF.op.axis + + a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32) + + a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32) + a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f) + s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f) + + (m, k) = CF.op.input_tensors[0].shape + (n, c, n_i, c_i) = CF.op.input_tensors[1].shape + n = n * n_i + + s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k))) + s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n))) + + if C == O: + fused = s[O].fuse(a_x3, a_y3) + else: + a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O) + a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O) + + s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, a_xi) + s[O].vectorize(a_xi) + s[C].compute_at(s[O], a_y2) + + fused = s[O].fuse(a_y3, a_x3) + + if do_parallel: + s[O].parallel(fused) + + return s, fused + + +@autotvm.register_topi_compute("dense_amx_int8.x86") +def dense_amx_int8(cfg, data, weight, bias=None, out_dtype=None): + """Compute for uint8 x int8 -> int32 dense""" + if out_dtype is None: + out_dtype = data.dtype + assert len(weight.shape) == 4 + assert data.dtype == "uint8" and weight.dtype == "int8" + _, _, _, n_inner = get_const_tuple(weight.shape) # out_dim + assert n_inner == 4 + return dense_amx_int8_compute(cfg, data, weight, bias) + + +@autotvm.register_topi_schedule("dense_amx_int8.x86") +def schedule_dense_amx_int8(cfg, outs): + """Create a schedule for dense_amx_int8""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "dense_amx_int8" in op.tag: + dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) + + traverse_inline(s, outs[0].op, _callback) + return s + + def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib): """Compute matmul/dense using a BLAS library""" M, K = get_const_tuple(tensor_a.shape) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index fd2b184a87d2..9a2fd4b44ed7 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -25,13 +25,15 @@ from ..utils import get_const_tuple from ..nn import dense_alter_layout from .utils import target_has_vnni +from .utils import target_has_amx from .. 
import nn -def check_vnni_applicable(x, y, allow_padding=False): +def check_inst_applicable(x, y, allow_padding=False): mcpu = tvm.target.Target.current().mcpu + cpu_avai = target_has_vnni(mcpu) or target_has_amx(mcpu) return ( - target_has_vnni(mcpu) + cpu_avai and "int8" in x.dtype and "int8" in y.dtype and (allow_padding or (y.shape[-2] % 16 == 0 and y.shape[-1] % 4 == 0)) @@ -47,7 +49,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): M, K = get_const_tuple(data_tensor.shape) N, _ = get_const_tuple(weight_tensor.shape) - if check_vnni_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8": + if check_inst_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8": weight_layout = "NC16n4c" return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype) @@ -87,7 +89,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): def vnni_legalize(inputs, arg_types, op, attrs, need_expand=False): """Legalizes s8, s8 -> s32 GEMM op for VNNI.""" if ( - check_vnni_applicable(arg_types[0], arg_types[1], allow_padding=True) + check_inst_applicable(arg_types[0], arg_types[1], allow_padding=True) and arg_types[0].dtype == "int8" ): x, y = inputs diff --git a/python/tvm/topi/x86/tensor_intrin.py b/python/tvm/topi/x86/tensor_intrin.py index 90857dbae160..ffa32d74c467 100644 --- a/python/tvm/topi/x86/tensor_intrin.py +++ b/python/tvm/topi/x86/tensor_intrin.py @@ -349,13 +349,14 @@ def _instr(index): default_buffer_params=buffer_params, ) + def dot_32x128x32_u8s8s32_sapphirerapids(LDA): - """ + """ Int8 dot product by every 16x64 elements using AMX-TMUL Sapphire Rapids instructions. The tdpxxd instruction takes two tile of uint8 and int8 datatype -- data[16][64] and kernel[1][16][16][4] -- and computes a dot product of data[16][16] in int32 datatype. - (Physically, to efficiently leveraging the tile register, we constructing a 2x2 tiles + (Physically, to efficiently leveraging the tile register, we constructing a 2x2 tiles matmul which performs 32x128x32 in total) The pseudo code is as follows: @@ -371,13 +372,14 @@ def dot_32x128x32_u8s8s32_sapphirerapids(LDA): } Args: - LDA (int): the stride of the matrix A, which is uint8 type and - use it to determine memory strides of macro reduce axis. + LDA (int): the stride of the matrix A, which is uint8 type and use it to determine + memory strides of macro reduce axis. 
Returns ------- intrin : TensorIntrin - The Sapphire Rapids AMX-TMUL int8 tdpbusd TensorIntrin that can be used in tensorizing schedule + The Sapphire Rapids AMX-TMUL int8 tdpbusd TensorIntrin that can be used in tensorizing + schedule """ A = te.placeholder((32, 128), name="A", dtype="uint8") B = te.placeholder((2, 32, 16, 4), name="B", dtype="int8") @@ -393,99 +395,180 @@ def dot_32x128x32_u8s8s32_sapphirerapids(LDA): name="C", ) - BA = tvm.tir.decl_buffer(A.shape, A.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BA") - BB = tvm.tir.decl_buffer(B.shape, B.dtype, offset_factor=1, strides=[te.var("ldw"), te.var("ldw"), te.var("ldw"), 1], name="BB") - BC = tvm.tir.decl_buffer(C.shape, C.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BC",scope="amx.tmm") + BA = tvm.tir.decl_buffer( + A.shape, A.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BA" + ) + BB = tvm.tir.decl_buffer( + B.shape, + B.dtype, + offset_factor=1, + strides=[te.var("ldw"), te.var("ldw"), te.var("ldw"), 1], + name="BB", + ) + BC = tvm.tir.decl_buffer( + C.shape, C.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BC", scope="amx.tmm" + ) def intrin_func(ins, outs): - bufA = ins[0] + bufA = ins[0] bufB = ins[1] bufC = outs[0] assert LDA _strides_A = tvm.tir.const(LDA, dtype="uint64") - _strides_B_tile = tvm.tir.const(LDA/128, dtype="uint64") - + _strides_B_tile = tvm.tir.const(LDA / 128, dtype="uint64") + def init(): - ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", - tvm.tir.const(1, "uint8"), - tvm.tir.const(0, dtype="uint8") )) # tile C 0 - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", - tvm.tir.const(1, "uint8"), - tvm.tir.const(1, dtype="uint8") )) # tile C 1 - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", - tvm.tir.const(1, "uint8"), - tvm.tir.const(2, dtype="uint8") )) # tile C 2 - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilezero", - tvm.tir.const(1, "uint8"), - tvm.tir.const(3, dtype="uint8") )) # tile C 3 + ib = tvm.tir.ir_builder.create() + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(0, dtype="uint8"), + ) + ) # tile C 0 + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(1, dtype="uint8"), + ) + ) # tile C 1 + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(2, dtype="uint8"), + ) + ) # tile C 2 + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilezero", + tvm.tir.const(1, "uint8"), + tvm.tir.const(3, dtype="uint8"), + ) + ) # tile C 3 return ib.get() - def body(): # load A, load B, dpbusd, store C + def body(): # load A, load B, dpbusd, store C ib = tvm.tir.ir_builder.create() - - for k_tile in range(2): # reduced data blocks - for n_acc in range(2): # broadcast data blocks - tmm_B_ = tvm.tir.const(n_acc+6, dtype="uint8") - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tileloaddt164", # load B: tmm6, tmm7 - tvm.tir.const(3, "uint8"), - tmm_B_, bufB.access_ptr("r", offset=64*16*(n_acc*2*_strides_B_tile + k_tile)), tvm.tir.const(64, dtype="uint64"))) - - for m_acc in range(2): # loaded data blocks - tmm_A_ = tvm.tir.const(m_acc+4, dtype="uint8") + + for k_tile in range(2): # reduced data blocks + for n_acc in range(2): # broadcast data blocks + tmm_B_ = tvm.tir.const(n_acc + 6, dtype="uint8") + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + 
"llvm.x86.tileloaddt164", # load B: tmm6, tmm7 + tvm.tir.const(3, "uint8"), + tmm_B_, + bufB.access_ptr( + "r", offset=64 * 16 * (n_acc * 2 * _strides_B_tile + k_tile) + ), + tvm.tir.const(64, dtype="uint64"), + ) + ) + + for m_acc in range(2): # loaded data blocks + tmm_A_ = tvm.tir.const(m_acc + 4, dtype="uint8") if n_acc == 0: - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tileloaddt164", # load A: , tmm4, tmm5 - tvm.tir.const(3, "uint8"), - tmm_A_, bufA.access_ptr("r", offset=m_acc*16*_strides_A + k_tile*64), _strides_A)) + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tileloaddt164", # load A: , tmm4, tmm5 + tvm.tir.const(3, "uint8"), + tmm_A_, + bufA.access_ptr( + "r", offset=m_acc * 16 * _strides_A + k_tile * 64 + ), + _strides_A, + ) + ) tmm_C_ = tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8") - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tdpbusd", - tvm.tir.const(3, "uint8"), - tmm_C_, tmm_A_, tmm_B_)) # tdpxxd + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tdpbusd", + tvm.tir.const(3, "uint8"), + tmm_C_, + tmm_A_, + tmm_B_, + ) + ) # tdpxxd return ib.get() # body, reset, store - return body(), init(), body(), + return ( + body(), + init(), + body(), + ) return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + def acc_32x32_int32_sapphirerapids(LDC): - """ + """ Store the accumulated tile register in scope amx.tmm to global memory. (tmm0, tmm1, tmm2, tmm3 --> global 4 tiles) Args: - LDC (int): the stride of the matrix C, which is int32 type and - use it to determine memory strides. + LDC (int): the stride of the matrix C, which is int32 type and use it to + determine memory strides. Returns ------- intrin : TensorIntrin - The Sapphirerapids AMX-TMUL int8 tilestored64 TensorIntrin that can be used in tensorizing schedule + The Sapphirerapids AMX-TMUL int8 tilestored64 TensorIntrin that can be used + in tensorizing schedule """ A = te.placeholder((32, 32), name="A", dtype="int32") - bufA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="amx.tmm", name="a_buffer", offset_factor=1, strides=[te.var("ldw"), 1]) + bufA = tvm.tir.decl_buffer( + A.shape, + A.dtype, + scope="amx.tmm", + name="a_buffer", + offset_factor=1, + strides=[te.var("ldw"), 1], + ) C = te.compute((32, 32), lambda i, j: A[i, j], name="C") - bufC = tvm.tir.decl_buffer(C.shape, C.dtype, scope="global", name="c_buffer", offset_factor=1, strides=[te.var("ldw"), 1]) + bufC = tvm.tir.decl_buffer( + C.shape, + C.dtype, + scope="global", + name="c_buffer", + offset_factor=1, + strides=[te.var("ldw"), 1], + ) assert LDC - _strides_C = tvm.tir.const(4*LDC, dtype="uint64") + _strides_C = tvm.tir.const(4 * LDC, dtype="uint64") def intrin_func(ins, outs): ib = tvm.tir.ir_builder.create() bufA = ins[0] bufC = outs[0] - for n_acc in range(2): # broadcast data blocks - for m_acc in range(2): # loaded data blocks - ib.emit(tvm.tir.call_llvm_intrin("int32", "llvm.x86.tilestored64", - tvm.tir.const(3, "uint8"), - tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8"), - bufC.access_ptr( "w", offset=n_acc*16 + m_acc*16*_strides_C/4), _strides_C)) + for n_acc in range(2): # broadcast data blocks + for m_acc in range(2): # loaded data blocks + ib.emit( + tvm.tir.call_llvm_intrin( + "int32", + "llvm.x86.tilestored64", + tvm.tir.const(3, "uint8"), + tvm.tir.const(m_acc * 2 + n_acc, dtype="uint8"), + bufC.access_ptr("w", offset=n_acc * 16 + m_acc * 16 * _strides_C / 4), + _strides_C, + ) + ) return ib.get() - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: bufA, C: bufC}) \ No 
newline at end of file
+    return te.decl_tensor_intrin(C.op, intrin_func, binds={A: bufA, C: bufC})
diff --git a/tests/python/contrib/test_amx.py b/tests/python/contrib/test_amx.py
index a49b7f30edb6..30da7e56fb8d 100644
--- a/tests/python/contrib/test_amx.py
+++ b/tests/python/contrib/test_amx.py
@@ -17,6 +17,8 @@
 # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
 import tvm
+from tvm import relay
+
 from tvm import te
 import tvm.testing
 from tvm.topi.x86.tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids
@@ -24,25 +26,26 @@
 import numpy as np
 import pytest

+
 @tvm.testing.requires_llvm
-@pytest.mark.skip("skip because AMX feature not avaliable yet")
+@pytest.mark.skip("skip because the AMX feature is not available yet")
 def test_amx_u8s8s32_matmul_tensorize():
     m = 1024
     k = 1024
     n = 1024
-
+
     # --------------------------Config---------------------------
     # Skip this test if "-mcpu=sapphirerapids" not supported by LLVM < 12.0
-    target="llvm -mcpu=sapphirerapids"
+    target = "llvm -mcpu=sapphirerapids"
     dev = tvm.device(target, 0)
     if not tvm.testing.device_enabled(target):
         print("skip because %s is not enabled..." % target)
         return
-
+
     amx_init = tvm.get_global_func("runtime.amx_init")
     amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig")
     assert amx_init()
-    assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns.
+    assert amx_tileconfig(16, 64)  # config tile size to 16 rows by 64 columns.
     # --------------------------Compute--------------------------
     X = te.placeholder((m, k), name="X", dtype="uint8")
     ak = te.reduce_axis((0, k), name="k")
@@ -52,7 +55,9 @@ def test_amx_u8s8s32_matmul_tensorize():
         (m, n),
         lambda i, j: te.sum(
             X[i, ak].astype("int32")
-            * packedW[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype("int32"),
+            * packedW[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype(
+                "int32"
+            ),
             axis=ak,
         ),
         name="F",
@@ -69,14 +74,13 @@ def test_amx_u8s8s32_matmul_tensorize():
     s[C].reorder(a_xo, a_yo, a_xi, a_yi)
     s[CF].compute_at(s[C], a_yo)

-    (a_k_f,) = CF.op.reduce_axis
+    (a_k_f,) = CF.op.reduce_axis
     a_x_f, a_y_f = CF.op.axis
     a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32)
     a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32)
     a_ko_f, a_ki_f = s[CF].split(a_k_f, factor=128)
-    s[CF].reorder(a_ko_f, a_xo_f, a_yo_f,
-                  a_ki_f, a_xi_f, a_yi_f)
+    s[CF].reorder(a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f)
     s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=k))
     s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=n))
@@ -87,7 +91,7 @@ def test_amx_u8s8s32_matmul_tensorize():
     assert "tileloaddt1" in asm
     assert "tdpbusd" in asm
     assert "tilestored" in asm
-
+
     # ----------------------- verify correctness --------------------------------
     # generate the plain data
     a = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
@@ -98,7 +102,7 @@ def test_amx_u8s8s32_matmul_tensorize():
     # from plain data to blocked data(NC16n4c)
     for i_n in range(n):
         for i_k in range(k):
-            packW[i_n//16][i_k//4][i_n%16][i_k%4] = b[i_n][i_k]
+            packW[i_n // 16][i_k // 4][i_n % 16][i_k % 4] = b[i_n][i_k]

     x = tvm.nd.array(a, dev)
     w = tvm.nd.array(packW, dev)
@@ -107,14 +111,16 @@ def test_amx_u8s8s32_matmul_tensorize():
     result = t_evaluator(x, w, y)
     print(result)
     tvm.testing.assert_allclose(y.numpy(), np.dot(a.astype("int32"), b.T.astype("int32")), rtol=0)
-    print("[TEST PASS ! ! ! ]")
]") + +@tvm.testing.requires_llvm +@pytest.mark.skip("skip due to AMX feature not avaliable yet") def test_amx_check_support(): amx_init = tvm.get_global_func("runtime.amx_init") amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") assert amx_init() assert amx_tileconfig(16, 64) + if __name__ == "__main__": - test_amx_check_support() - test_amx_u8s8s32_matmul_tensorize() \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index bd4e1b72c3cd..1d3a88d43edb 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -799,6 +799,48 @@ def test_dense_vnni(m, n, k): np.testing.assert_equal(out, ref) +@tvm.testing.requires_llvm +@pytest.mark.skip("skip due to AMX feature not avaliable yet") +def test_dense_amx_int8(): + data_shape = (32, 128) + weight_shape = (32, 128) + + for data_dtype in ["uint8", "int8"]: + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype="int8") + bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") + dense = relay.nn.dense(data, weight, out_dtype="int32") + out = relay.nn.bias_add(dense, bias) + mod = tvm.IRModule.from_expr(out) + + target = "llvm -mcpu=sapphirerapids" + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target) + + asm = lib.lib.get_source("asm") + assert "tilezero" in asm + assert "tileloaddt1" in asm + assert "tdpbusd" in asm + assert "tilestored" in asm + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + a = np.random.uniform(1, 10, size=data_shape).astype(data_dtype) + b = np.random.uniform(1, 10, size=weight_shape).astype("int8") + c = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") + + runtime.set_input("data", a) + runtime.set_input("weight", b) + runtime.set_input("bias", c) + runtime.run() + + out = runtime.get_output(0).numpy() + ref = np.dot(a.astype("int32"), b.transpose().astype("int32")) + c + + np.testing.assert_equal(out, ref) + + @pytest.mark.skip("Requires GFX10 AMDGPU") def test_dense_rocm_sdot4(): data_shape = (32, 96) From c53c39445307670c27d18ac1c858e5a5299acf5f Mon Sep 17 00:00:00 2001 From: Qianshui Date: Mon, 19 Dec 2022 04:35:55 -0500 Subject: [PATCH 05/21] add amx init() and config() for dense test case --- tests/python/relay/test_op_level1.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 1d3a88d43edb..0993a22f4165 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -766,6 +766,11 @@ def test_dense_vnni(m, n, k): data_shape = (m, k) weight_shape = (n, k) + amx_init = tvm.get_global_func("runtime.amx_init") + amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") + assert amx_init() + assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns. 
+ for data_dtype in ["uint8", "int8"]: data = relay.var("data", shape=data_shape, dtype=data_dtype) weight = relay.var("weight", shape=weight_shape, dtype="int8") From 79d6636a89e17c56e3b957cb07878c445f85c546 Mon Sep 17 00:00:00 2001 From: Qianshui Date: Mon, 19 Dec 2022 04:44:08 -0500 Subject: [PATCH 06/21] correct the amx config --- tests/python/relay/test_op_level1.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 0993a22f4165..9f31acfa6d7f 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -766,11 +766,6 @@ def test_dense_vnni(m, n, k): data_shape = (m, k) weight_shape = (n, k) - amx_init = tvm.get_global_func("runtime.amx_init") - amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") - assert amx_init() - assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns. - for data_dtype in ["uint8", "int8"]: data = relay.var("data", shape=data_shape, dtype=data_dtype) weight = relay.var("weight", shape=weight_shape, dtype="int8") @@ -810,6 +805,11 @@ def test_dense_amx_int8(): data_shape = (32, 128) weight_shape = (32, 128) + amx_init = tvm.get_global_func("runtime.amx_init") + amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") + assert amx_init() + assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns. + for data_dtype in ["uint8", "int8"]: data = relay.var("data", shape=data_shape, dtype=data_dtype) weight = relay.var("weight", shape=weight_shape, dtype="int8") From b866673696f556d2e06a71c277c570ff1fb0fcd2 Mon Sep 17 00:00:00 2001 From: Qianshui Date: Mon, 19 Dec 2022 05:31:07 -0500 Subject: [PATCH 07/21] fix lint. --- python/tvm/topi/x86/dense.py | 4 ++-- python/tvm/topi/x86/tensor_intrin.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index d7e971c116e7..24b391e9f9e7 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,too-many-locals,unused-variable -# pylint: disable=no-value-for-parameter +# pylint: disable=invalid-name,too-many-locals,unused-argument +# pylint: disable=no-value-for-parameter,unused-variable """x86 dense operators""" from __future__ import absolute_import as _abs diff --git a/python/tvm/topi/x86/tensor_intrin.py b/python/tvm/topi/x86/tensor_intrin.py index ffa32d74c467..3b83fecbf552 100644 --- a/python/tvm/topi/x86/tensor_intrin.py +++ b/python/tvm/topi/x86/tensor_intrin.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""Core kernel of dot product of 4 Int8 operations""" -# pylint: disable=invalid-name +# pylint: disable=invalid-name,unused-variable import tvm from tvm import te import tvm.target.codegen @@ -409,7 +409,7 @@ def dot_32x128x32_u8s8s32_sapphirerapids(LDA): C.shape, C.dtype, offset_factor=1, strides=[te.var("ldw"), 1], name="BC", scope="amx.tmm" ) - def intrin_func(ins, outs): + def intrin_func(ins, outs): # pylint: disable=unused-variable bufA = ins[0] bufB = ins[1] bufC = outs[0] @@ -552,7 +552,7 @@ def acc_32x32_int32_sapphirerapids(LDC): assert LDC _strides_C = tvm.tir.const(4 * LDC, dtype="uint64") - def intrin_func(ins, outs): + def intrin_func(ins, outs): # pylint: disable=unused-variable ib = tvm.tir.ir_builder.create() bufA = ins[0] bufC = outs[0] From 48fa37ea7bfdbaf5eed5a04e01c47ae44bbb7c37 Mon Sep 17 00:00:00 2001 From: Qianshui Date: Mon, 19 Dec 2022 11:08:49 -0500 Subject: [PATCH 08/21] fix dense schedule --- python/tvm/topi/x86/dense.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 24b391e9f9e7..0695b7cfd161 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -488,9 +488,8 @@ def split_k(out, rd_axis): s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, a_xi) s[O].vectorize(a_xi) - s[C].compute_at(s[O], a_y2) - fused = s[O].fuse(a_y3, a_x3) + fused = s[O].fuse(a_x3, a_y3) if do_parallel: s[O].parallel(fused) From dd1eb2498a4e2c867a16d7d9ccea010c396e10ae Mon Sep 17 00:00:00 2001 From: Qianshui Date: Thu, 22 Dec 2022 21:55:59 -0500 Subject: [PATCH 09/21] remove operation of signal stack --- cmake/config.cmake | 3 +++ src/runtime/contrib/amx/amx_config.cc | 8 -------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index 679f5c459e87..3189815d237c 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -174,6 +174,9 @@ set(USE_MKL OFF) # - OFF: Disable DNNL set(USE_DNNL OFF) +# Whether use Intel AMX instructions. 
+set(USE_AMX OFF) + # Whether use OpenMP thread pool, choices: gnu, intel # Note: "gnu" uses gomp library, "intel" uses iomp5 library set(USE_OPENMP none) diff --git a/src/runtime/contrib/amx/amx_config.cc b/src/runtime/contrib/amx/amx_config.cc index f04126405a03..2e034bd478b5 100644 --- a/src/runtime/contrib/amx/amx_config.cc +++ b/src/runtime/contrib/amx/amx_config.cc @@ -90,14 +90,6 @@ TVM_REGISTER_GLOBAL("runtime.amx_tileconfig").set_body([](TVMArgs args, TVMRetVa // register a global packed function in c++,to init the system for AMX config TVM_REGISTER_GLOBAL("runtime.amx_init").set_body([](TVMArgs args, TVMRetValue* rv) { - // -----------Enlarge the signal stack in linux---------------------- - char largestack[15535]; // SIGSTKSZ=8192, MINSIGSTKSZ=2048, 15535 - stack_t ss; - ss.ss_sp = largestack; - ss.ss_size = sizeof(largestack); - ss.ss_flags = 0; - if (sigaltstack(&ss, NULL)) exit(-1); - // -----------Detect and request for AMX control---------------------- uint64_t bitmask = 0; int64_t status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); From 73f45efda983cb85a755eaee217e22ebcc32ca80 Mon Sep 17 00:00:00 2001 From: Qianshui Date: Sun, 25 Dec 2022 08:37:56 -0500 Subject: [PATCH 10/21] fix nit --- python/tvm/topi/x86/dense_alter_op.py | 4 ++-- tests/python/contrib/test_gemm_acc32_vnni.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 9a2fd4b44ed7..2cb46b8291fb 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -31,9 +31,9 @@ def check_inst_applicable(x, y, allow_padding=False): mcpu = tvm.target.Target.current().mcpu - cpu_avai = target_has_vnni(mcpu) or target_has_amx(mcpu) + simd_avai = target_has_vnni(mcpu) or target_has_amx(mcpu) return ( - cpu_avai + simd_avai and "int8" in x.dtype and "int8" in y.dtype and (allow_padding or (y.shape[-2] % 16 == 0 and y.shape[-1] % 4 == 0)) diff --git a/tests/python/contrib/test_gemm_acc32_vnni.py b/tests/python/contrib/test_gemm_acc32_vnni.py index 9121da21476a..9cec823cc58a 100644 --- a/tests/python/contrib/test_gemm_acc32_vnni.py +++ b/tests/python/contrib/test_gemm_acc32_vnni.py @@ -115,5 +115,5 @@ def verify(target="llvm -mcpu=cascadelake"): # The test requires Cascade Lake and newer Intel machines to generate the # correct AVX512 VNNI instruction. So, disabling the test. 
-    # test_fc_int8_acc32()``
+    # test_fc_int8_acc32()
     pass

From b9210524302a56e0935cb4667b93b80cb277aa85 Mon Sep 17 00:00:00 2001
From: Qianshui
Date: Thu, 29 Dec 2022 02:07:49 -0500
Subject: [PATCH 11/21] unify amx and vnni compute, remove the duplicate

---
 python/tvm/relay/op/strategy/x86.py |  23 ++-----
 python/tvm/topi/x86/dense.py        | 103 +++++++++-------------------
 2 files changed, 38 insertions(+), 88 deletions(-)

diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index ba310cae21cd..a99732693c52 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -591,33 +591,18 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
 def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
     """dense_pack x86 strategy"""
     strategy = _op.OpStrategy()
-    mcpu = Target.current().mcpu

     if (
-        target_has_amx(mcpu)
-        and inputs[0].dtype == "uint8"
+        inputs[0].dtype == "uint8"
         and inputs[1].dtype == "int8"
         and out_type.dtype == "int32"
         and attrs["weight_layout"] == "NC16n4c"
     ):
         strategy.add_implementation(
-            wrap_compute_dense(topi.x86.dense_amx_int8),
-            wrap_topi_schedule(topi.x86.schedule_dense_amx_int8),
-            name="dense_amx_int8.x86",
+            wrap_compute_dense(topi.x86.dense_int8),
+            wrap_topi_schedule(topi.x86.schedule_dense_int8),
+            name="dense_int8.x86",
             plevel=13,
         )
-    elif (
-        target_has_vnni(mcpu)
-        and inputs[0].dtype == "uint8"
-        and inputs[1].dtype == "int8"
-        and out_type.dtype == "int32"
-        and attrs["weight_layout"] == "NC16n4c"
-    ):
-        strategy.add_implementation(
-            wrap_compute_dense(topi.x86.dense_vnni),
-            wrap_topi_schedule(topi.x86.schedule_dense_vnni),
-            name="dense_vnni.x86",
-            plevel=12,
-        )
     else:
         strategy.add_implementation(
             wrap_compute_dense(topi.x86.dense_pack),
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 0695b7cfd161..c8faf9608098 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -29,7 +29,7 @@
 from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake
 from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids
 from .tensor_intrin import acc_32x32_int32_sapphirerapids
-from .utils import get_simd_32bit_lanes
+from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx


 def _schedule_dense_pack_template(cfg, s, C, O):
@@ -280,7 +280,36 @@ def _callback(op):
     return s


-def dense_vnni_compute(cfg, X, packed_w, bias=None):
+@autotvm.register_topi_compute("dense_int8.x86")
+def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
+    """Compute for uint8 x int8 -> int32 dense"""
+    if out_dtype is None:
+        out_dtype = data.dtype
+    assert len(weight.shape) == 4
+    assert data.dtype == "uint8" and weight.dtype == "int8"
+    _, _, n_inner, k_inner = get_const_tuple(weight.shape)  # out_dim
+    assert n_inner == 16 and k_inner == 4
+    return dense_int8_compute(cfg, data, weight, bias)
+
+
+@autotvm.register_topi_schedule("dense_int8.x86")
+def schedule_dense_int8(cfg, outs):
+    """Create a schedule for dense_int8"""
+    s = te.create_schedule([x.op for x in outs])
+    mcpu = tvm.target.Target.current().mcpu
+
+    def _callback(op):
+        if "dense_int8" in op.tag:
+            if target_has_amx(mcpu):
+                dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
+            elif target_has_vnni(mcpu):
+                dense_vnni_schedule(cfg, s, op.output(0), outs[0])
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def dense_int8_compute(cfg, X, packed_w, bias=None):
     """Compute for uint8 x int8 -> int32 dense"""
     m, k = X.shape
     n_o, _, n_i, _ = packed_w.shape
@@ -295,16 +324,13
@@ def dense_vnni_compute(cfg, X, packed_w, bias=None): ), axis=ak, ), - tag="dense_vnni", - attrs={"schedule_rule": "dense_vnni"}, + tag="dense_int8", + attrs={"schedule_rule": "dense_int8"}, ) if bias is not None: C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) - a_y, _ = C.op.axis - cfg.define_split("tile_y", a_y, num_outputs=2) - return C @@ -319,6 +345,7 @@ def split_y(out): if cfg.is_fallback: return s[out].split(a_y, factor=default_y_split_factor) + cfg.define_split("tile_y", a_y, num_outputs=2) return cfg["tile_y"].apply(s, out, a_y) (a_k,) = C.op.reduce_axis @@ -350,56 +377,6 @@ def split_y(out): return s, fused -@autotvm.register_topi_compute("dense_vnni.x86") -def dense_vnni(cfg, data, weight, bias=None, out_dtype=None): - """Compute for uint8 x int8 -> int32 dense""" - if out_dtype is None: - out_dtype = data.dtype - assert len(weight.shape) == 4 - assert data.dtype == "uint8" and weight.dtype == "int8" - _, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim - assert n_inner == 16 and k_inner == 4 - return dense_vnni_compute(cfg, data, weight, bias) - - -@autotvm.register_topi_schedule("dense_vnni.x86") -def schedule_dense_vnni(cfg, outs): - """Create a schedule for dense_vnni""" - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if "dense_vnni" in op.tag: - dense_vnni_schedule(cfg, s, op.output(0), outs[0]) - - traverse_inline(s, outs[0].op, _callback) - return s - - -def dense_amx_int8_compute(cfg, data, packed_w, bias=None): - """Compute for uint8 x int8 -> int32 dense""" - m, k = data.shape - n_o, _, n_i, _ = packed_w.shape - ak = te.reduce_axis((0, k), name="k") - - C = te.compute( - (m, n_o * n_i), - lambda i, j: te.sum( - data[i, ak].astype("int32") - * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype( - "int32" - ), - axis=ak, - ), - tag="dense_amx_int8", - attrs={"schedule_rule": "meta_schedule.dense_amx_int8"}, - ) - - if bias is not None: - C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) - - return C - - def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True): """Schedule dense compute using AMX TMUL instruction""" # C: The output of GEMM @@ -497,18 +474,6 @@ def split_k(out, rd_axis): return s, fused -@autotvm.register_topi_compute("dense_amx_int8.x86") -def dense_amx_int8(cfg, data, weight, bias=None, out_dtype=None): - """Compute for uint8 x int8 -> int32 dense""" - if out_dtype is None: - out_dtype = data.dtype - assert len(weight.shape) == 4 - assert data.dtype == "uint8" and weight.dtype == "int8" - _, _, _, n_inner = get_const_tuple(weight.shape) # out_dim - assert n_inner == 4 - return dense_amx_int8_compute(cfg, data, weight, bias) - - @autotvm.register_topi_schedule("dense_amx_int8.x86") def schedule_dense_amx_int8(cfg, outs): """Create a schedule for dense_amx_int8""" From e74936014f8e3077e8335e558461f7e291ad0598 Mon Sep 17 00:00:00 2001 From: Qianshui Date: Thu, 29 Dec 2022 02:24:38 -0500 Subject: [PATCH 12/21] fix lint --- python/tvm/relay/op/strategy/x86.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index a99732693c52..4585809f63e1 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -25,7 +25,7 @@ from tvm.relay.ty import is_dynamic from tvm.target import Target from tvm.te import SpecializedCondition -from tvm.topi.x86.utils import target_has_vnni, target_has_amx +from tvm.topi.x86.utils import 
target_has_vnni
 from .. import op as _op
 from .generic import *

From 5718a059c69972cf71ea082a3303b5c29fa2d21f Mon Sep 17 00:00:00 2001
From: Qianshui
Date: Sat, 31 Dec 2022 09:18:12 -0500
Subject: [PATCH 13/21] adapt to x86 int8 dense compute method

---
 .../unittest/test_meta_schedule_vnni_integration.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py b/tests/python/unittest/test_meta_schedule_vnni_integration.py
index 3bbe916472f5..7caec83a1267 100644
--- a/tests/python/unittest/test_meta_schedule_vnni_integration.py
+++ b/tests/python/unittest/test_meta_schedule_vnni_integration.py
@@ -47,7 +47,7 @@ def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool:
     if dense_block is None:
         assert has_block(sch, "compute")
         dense_block = sch.get_block("compute")
-    assert "dense_vnni" in sch.get(dense_block).annotations["schedule_rule"]
+    assert "dense_int8" in sch.get(dense_block).annotations["schedule_rule"]

     post_blocks = sch.get_consumers(dense_block)
     if len(post_blocks) > 0:
@@ -176,12 +176,12 @@ def test_vnni_schedule_fn_tune():
         C = te.compute(
             ...
-            attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"},
+            attrs={"schedule_rule": "meta_schedule.x86.dense_int8"},
         )

     When the MetaSchedule encounters a TensorIR block with the "schedule_rule" annotation,
     it looks up the packed func registry for a function that is associated with the given schedule
-    rule key ("meta_schedule.x86.dense_vnni" in this example). The signature of such custom
+    rule key ("meta_schedule.x86.dense_int8" in this example). The signature of such custom
     schedule functions must be
     (tir.schedule.Schedule, tir.schedule.BlockRV) -> [tir.schedule.Schedule].

@@ -195,7 +195,7 @@ def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV):
         _schedule_dense(m=None, do_tune=True)(sch, dense_block)
         return [sch]

-    register_func("meta_schedule.x86.dense_vnni", schedule_rule_dense_vnni)
+    register_func("meta_schedule.x86.dense_int8", schedule_rule_dense_vnni)

     m, n, k = 1024, 1024, 1024
     target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=cascadelake -num-cores=4")

From 581331a5645fd45f1424166496ad4e7c091f3f8d Mon Sep 17 00:00:00 2001
From: Qianshui
Date: Sun, 1 Jan 2023 07:27:25 -0500
Subject: [PATCH 14/21] Revert "adapt to x86 int8 dense compute method"

This reverts commit 5718a059c69972cf71ea082a3303b5c29fa2d21f.

---
 .../unittest/test_meta_schedule_vnni_integration.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/python/unittest/test_meta_schedule_vnni_integration.py b/tests/python/unittest/test_meta_schedule_vnni_integration.py
index 7caec83a1267..3bbe916472f5 100644
--- a/tests/python/unittest/test_meta_schedule_vnni_integration.py
+++ b/tests/python/unittest/test_meta_schedule_vnni_integration.py
@@ -47,7 +47,7 @@ def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool:
     if dense_block is None:
         assert has_block(sch, "compute")
         dense_block = sch.get_block("compute")
-    assert "dense_int8" in sch.get(dense_block).annotations["schedule_rule"]
+    assert "dense_vnni" in sch.get(dense_block).annotations["schedule_rule"]

     post_blocks = sch.get_consumers(dense_block)
     if len(post_blocks) > 0:
@@ -176,12 +176,12 @@ def test_vnni_schedule_fn_tune():
         C = te.compute(
             ...
- attrs={"schedule_rule": "meta_schedule.x86.dense_int8"}, + attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"}, ) When the MetaSchedule encounters a TensorIR block with the "schedule_rule" annotation, it looks up the packed func registry for a function that is associated with the given schedule - rule key ("meta_schedule.x86.dense_int8" in this example). The signature of such custom + rule key ("meta_schedule.x86.dense_vnni" in this example). The signature of such custom schedule functions must be (tir.schedule.Schedule, tir.schedule.BlockRV) -> [tir.schedule.Schedule]. @@ -195,7 +195,7 @@ def schedule_rule_dense_vnni(sch: Schedule, dense_block: BlockRV): _schedule_dense(m=None, do_tune=True)(sch, dense_block) return [sch] - register_func("meta_schedule.x86.dense_int8", schedule_rule_dense_vnni) + register_func("meta_schedule.x86.dense_vnni", schedule_rule_dense_vnni) m, n, k = 1024, 1024, 1024 target = tvm.target.Target("llvm -keys=x86,cpu -mcpu=cascadelake -num-cores=4") From 2bda03e0ed67d86a90511ee9eb7afaa2215bad17 Mon Sep 17 00:00:00 2001 From: Qianshui Date: Sun, 1 Jan 2023 07:34:40 -0500 Subject: [PATCH 15/21] restore schedule ruls specially for ms dense_vnni --- python/tvm/topi/x86/dense.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index c8faf9608098..5f5990b32b26 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -325,7 +325,7 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): axis=ak, ), tag="dense_int8", - attrs={"schedule_rule": "dense_int8"}, + attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"}, ) if bias is not None: From c2e9f26fd9d84ce75e9a0c1474df1b7e0b9ff4f3 Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Wed, 4 Jan 2023 11:54:53 +0000 Subject: [PATCH 16/21] add vnni ms target attributes --- .../qemu-hack/qemu-system-arm | 44 ++++++++++++++++++- .../qemu-hack/qemu-system-riscv32 | 44 ++++++++++++++++++- .../qemu-hack/qemu-system-riscv64 | 44 ++++++++++++++++++- .../qemu-hack/qemu-system-xilinx-aarch64 | 44 ++++++++++++++++++- apps/sgx/.rustfmt.toml | 32 +++++++++++++- python/tvm/topi/x86/dense.py | 7 ++- 6 files changed, 209 insertions(+), 6 deletions(-) mode change 120000 => 100755 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm mode change 120000 => 100755 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 mode change 120000 => 100755 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 mode change 120000 => 100755 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 mode change 120000 => 100644 apps/sgx/.rustfmt.toml diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm deleted file mode 120000 index ebbc8ad5ad9d..000000000000 --- a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm +++ /dev/null @@ -1 +0,0 @@ -qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm new file mode 100755 index 000000000000..2d350698edb9 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm @@ -0,0 +1,43 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Zephyr insists on running qemu with a -pidfile option, but that option doesn't appear to +# work given the way we've configured docker (the underlying filesystem doesn't support the +# file locking it needs to). This script strips any -pidfile option, then invokes qemu. + +ARGS=( "$(basename $0)" ) + +if [ "${QEMU_BIN_PATH}" != "" ]; then + ARGS=${QEMU_BIN_PATH}/${ARGS} +fi + +while [ "$#" -gt 0 ]; do + if [ "$1" == "-pidfile" ]; then + shift + else + ARGS=( "${ARGS[@]}" "$1" ) + fi + shift +done + +# For debugging +if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then + ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S ) +fi + +"${ARGS[@]}" diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 deleted file mode 120000 index ebbc8ad5ad9d..000000000000 --- a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 +++ /dev/null @@ -1 +0,0 @@ -qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 new file mode 100755 index 000000000000..2d350698edb9 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 @@ -0,0 +1,43 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Zephyr insists on running qemu with a -pidfile option, but that option doesn't appear to +# work given the way we've configured docker (the underlying filesystem doesn't support the +# file locking it needs to). This script strips any -pidfile option, then invokes qemu. 
+ +ARGS=( "$(basename $0)" ) + +if [ "${QEMU_BIN_PATH}" != "" ]; then + ARGS=${QEMU_BIN_PATH}/${ARGS} +fi + +while [ "$#" -gt 0 ]; do + if [ "$1" == "-pidfile" ]; then + shift + else + ARGS=( "${ARGS[@]}" "$1" ) + fi + shift +done + +# For debugging +if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then + ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S ) +fi + +"${ARGS[@]}" diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 deleted file mode 120000 index ebbc8ad5ad9d..000000000000 --- a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 +++ /dev/null @@ -1 +0,0 @@ -qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 new file mode 100755 index 000000000000..2d350698edb9 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 @@ -0,0 +1,43 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Zephyr insists on running qemu with a -pidfile option, but that option doesn't appear to +# work given the way we've configured docker (the underlying filesystem doesn't support the +# file locking it needs to). This script strips any -pidfile option, then invokes qemu. + +ARGS=( "$(basename $0)" ) + +if [ "${QEMU_BIN_PATH}" != "" ]; then + ARGS=${QEMU_BIN_PATH}/${ARGS} +fi + +while [ "$#" -gt 0 ]; do + if [ "$1" == "-pidfile" ]; then + shift + else + ARGS=( "${ARGS[@]}" "$1" ) + fi + shift +done + +# For debugging +if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then + ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S ) +fi + +"${ARGS[@]}" diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 deleted file mode 120000 index ebbc8ad5ad9d..000000000000 --- a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 +++ /dev/null @@ -1 +0,0 @@ -qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 new file mode 100755 index 000000000000..2d350698edb9 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 @@ -0,0 +1,43 @@ +#!/bin/bash -e +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Zephyr insists on running qemu with a -pidfile option, but that option doesn't appear to +# work given the way we've configured docker (the underlying filesystem doesn't support the +# file locking it needs to). This script strips any -pidfile option, then invokes qemu. + +ARGS=( "$(basename $0)" ) + +if [ "${QEMU_BIN_PATH}" != "" ]; then + ARGS=${QEMU_BIN_PATH}/${ARGS} +fi + +while [ "$#" -gt 0 ]; do + if [ "$1" == "-pidfile" ]; then + shift + else + ARGS=( "${ARGS[@]}" "$1" ) + fi + shift +done + +# For debugging +if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then + ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S ) +fi + +"${ARGS[@]}" diff --git a/apps/sgx/.rustfmt.toml b/apps/sgx/.rustfmt.toml deleted file mode 120000 index 27139e42a3f2..000000000000 --- a/apps/sgx/.rustfmt.toml +++ /dev/null @@ -1 +0,0 @@ -../../rust/.rustfmt.toml \ No newline at end of file diff --git a/apps/sgx/.rustfmt.toml b/apps/sgx/.rustfmt.toml new file mode 100644 index 000000000000..3c51bb384c68 --- /dev/null +++ b/apps/sgx/.rustfmt.toml @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+max_width = 100
+hard_tabs = false
+tab_spaces = 4
+newline_style = "Auto"
+use_small_heuristics = "Default"
+reorder_imports = true
+reorder_modules = true
+remove_nested_parens = true
+fn_args_layout = "Tall"
+edition = "2018"
+merge_derives = true
+use_try_shorthand = false
+use_field_init_shorthand = false
+force_explicit_abi = true
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 5f5990b32b26..ada19d598cdf 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -314,6 +314,11 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
     m, k = X.shape
     n_o, _, n_i, _ = packed_w.shape
     ak = te.reduce_axis((0, k), name="k")
+    mcpu = tvm.target.Target.current().mcpu
+    if target_has_vnni(mcpu):
+        target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"}
+    else:
+        target_attr = None

     C = te.compute(
@@ -325,7 +330,7 @@ def dense_int8_compute(cfg, X, packed_w, bias=None):
             axis=ak,
         ),
         tag="dense_int8",
-        attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"},
+        attrs=target_attr,
     )

     if bias is not None:

From 4469fd9436bd8155572ac38aef081de8744a4fc4 Mon Sep 17 00:00:00 2001
From: "Jiang, Qianshui"
Date: Thu, 5 Jan 2023 10:07:21 +0000
Subject: [PATCH 17/21] remove the accidental file changes

---
 .../qemu-hack/qemu-system-arm            | 43 -------------------
 .../qemu-hack/qemu-system-riscv32        | 43 -------------------
 .../qemu-hack/qemu-system-riscv64        | 43 -------------------
 .../qemu-hack/qemu-system-xilinx-aarch64 | 43 -------------------
 apps/sgx/.rustfmt.toml                   | 31 --------------
 5 files changed, 203 deletions(-)
 delete mode 100755 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm
 delete mode 100755 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32
 delete mode 100755 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64
 delete mode 100755 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64
 delete mode 100644 apps/sgx/.rustfmt.toml

diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm
deleted file mode 100755
index 2d350698edb9..000000000000
--- a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/bin/bash -e
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# Zephyr insists on running qemu with a -pidfile option, but that option doesn't appear to
-# work given the way we've configured docker (the underlying filesystem doesn't support the
-# file locking it needs to). This script strips any -pidfile option, then invokes qemu.
- -ARGS=( "$(basename $0)" ) - -if [ "${QEMU_BIN_PATH}" != "" ]; then - ARGS=${QEMU_BIN_PATH}/${ARGS} -fi - -while [ "$#" -gt 0 ]; do - if [ "$1" == "-pidfile" ]; then - shift - else - ARGS=( "${ARGS[@]}" "$1" ) - fi - shift -done - -# For debugging -if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then - ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S ) -fi - -"${ARGS[@]}" diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 deleted file mode 100755 index 2d350698edb9..000000000000 --- a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -e -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Zephyr insists on running qemu with a -pidfile option, but that option doesn't appear to -# work given the way we've configured docker (the underlying filesystem doesn't support the -# file locking it needs to). This script strips any -pidfile option, then invokes qemu. - -ARGS=( "$(basename $0)" ) - -if [ "${QEMU_BIN_PATH}" != "" ]; then - ARGS=${QEMU_BIN_PATH}/${ARGS} -fi - -while [ "$#" -gt 0 ]; do - if [ "$1" == "-pidfile" ]; then - shift - else - ARGS=( "${ARGS[@]}" "$1" ) - fi - shift -done - -# For debugging -if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then - ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S ) -fi - -"${ARGS[@]}" diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 deleted file mode 100755 index 2d350698edb9..000000000000 --- a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -e -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Zephyr insists on running qemu with a -pidfile option, but that option doesn't appear to -# work given the way we've configured docker (the underlying filesystem doesn't support the -# file locking it needs to). 
This script strips any -pidfile option, then invokes qemu. - -ARGS=( "$(basename $0)" ) - -if [ "${QEMU_BIN_PATH}" != "" ]; then - ARGS=${QEMU_BIN_PATH}/${ARGS} -fi - -while [ "$#" -gt 0 ]; do - if [ "$1" == "-pidfile" ]; then - shift - else - ARGS=( "${ARGS[@]}" "$1" ) - fi - shift -done - -# For debugging -if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then - ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S ) -fi - -"${ARGS[@]}" diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 deleted file mode 100755 index 2d350698edb9..000000000000 --- a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -e -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Zephyr insists on running qemu with a -pidfile option, but that option doesn't appear to -# work given the way we've configured docker (the underlying filesystem doesn't support the -# file locking it needs to). This script strips any -pidfile option, then invokes qemu. - -ARGS=( "$(basename $0)" ) - -if [ "${QEMU_BIN_PATH}" != "" ]; then - ARGS=${QEMU_BIN_PATH}/${ARGS} -fi - -while [ "$#" -gt 0 ]; do - if [ "$1" == "-pidfile" ]; then - shift - else - ARGS=( "${ARGS[@]}" "$1" ) - fi - shift -done - -# For debugging -if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then - ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S ) -fi - -"${ARGS[@]}" diff --git a/apps/sgx/.rustfmt.toml b/apps/sgx/.rustfmt.toml deleted file mode 100644 index 3c51bb384c68..000000000000 --- a/apps/sgx/.rustfmt.toml +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-
-max_width = 100
-hard_tabs = false
-tab_spaces = 4
-newline_style = "Auto"
-use_small_heuristics = "Default"
-reorder_imports = true
-reorder_modules = true
-remove_nested_parens = true
-fn_args_layout = "Tall"
-edition = "2018"
-merge_derives = true
-use_try_shorthand = false
-use_field_init_shorthand = false
-force_explicit_abi = true

From f763d52ca8cb4b5d925dff80c615d32dddc2988f Mon Sep 17 00:00:00 2001
From: "Jiang, Qianshui"
Date: Thu, 5 Jan 2023 10:28:39 +0000
Subject: [PATCH 18/21] Revert "restore schedule rules specifically for ms dense_vnni"

This reverts commit 2bda03e0ed67d86a90511ee9eb7afaa2215bad17.

---
 python/tvm/topi/x86/dense.py | 661 -----------------------------
 1 file changed, 661 deletions(-)
 delete mode 100644 python/tvm/topi/x86/dense.py

diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
deleted file mode 100644
index ada19d598cdf..000000000000
--- a/python/tvm/topi/x86/dense.py
+++ /dev/null
@@ -1,661 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name,too-many-locals,unused-argument
-# pylint: disable=no-value-for-parameter,unused-variable
-"""x86 dense operators"""
-from __future__ import absolute_import as _abs
-
-import tvm
-from tvm import autotvm, te
-from tvm.autotvm.task.space import SplitEntity
-from tvm.contrib import cblas, dnnl, mkl
-
-from ..
import generic, tag -from ..utils import get_const_tuple, traverse_inline -from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake -from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids -from .tensor_intrin import acc_32x32_int32_sapphirerapids -from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx - - -def _schedule_dense_pack_template(cfg, s, C, O): - A, packedB = s[C].op.input_tensors - - CC = s.cache_write(C, "global") - y, x = s[C].op.axis - (k,) = s[CC].op.reduce_axis - - yt, yo, yi = cfg["tile_y"].apply(s, C, y) - xt, xo, xi = cfg["tile_x"].apply(s, C, x) - s[C].reorder(xt, yt, yo, xo, yi, xi) - xyt = s[C].fuse(xt, yt) - if C == O: - s[C].parallel(xyt) - xyo = s[C].fuse(yo, xo) - s[C].unroll(yi) - s[C].vectorize(xi) - - s[CC].compute_at(s[C], xyo) - y, x = s[CC].op.axis - ko, ki = cfg["tile_k"].apply(s, CC, k) - s[CC].reorder(ko, ki, y, x) - s[CC].vectorize(x) - - tile_inner = cfg["tile_inner"].size[-1] - if tile_inner > 1: - yo, yi = s[CC].split(y, tile_inner) - s[CC].reorder(ko, yo, ki, yi, x) - s[CC].unroll(yo) - s[CC].unroll(ki) - s[CC].unroll(yi) - else: - s[CC].unroll(ki) - s[CC].unroll(y) - - if C != O: - y, x = s[O].op.axis - yt, yo, yi = cfg["tile_y"].apply(s, O, y) - xt, xo, xi = cfg["tile_x"].apply(s, O, x) - s[O].reorder(xt, yt, yo, xo, yi, xi) - xyt = s[O].fuse(xt, yt) - s[C].compute_at(s[O], xyt) - s[O].vectorize(xi) - s[O].parallel(xyt) - return s - - -def _schedule_dense_nopack_template(cfg, s, C): - y, x = s[C].op.axis - (kk,) = s[C].op.reduce_axis - yo, yi = cfg["tile_y"].apply(s, C, y) - xo, xi = cfg["tile_x"].apply(s, C, x) - s[C].reorder(yo, xo, yi, xi) - xyo = s[C].fuse(yo, xo) - s[C].parallel(xyo) - s[C].unroll(kk) - - (CC,) = s[C].op.input_tensors - s[CC].compute_at(s[C], xyo) - z, y, x = s[CC].op.axis - (k,) = s[CC].op.reduce_axis - yz = s[CC].fuse(z, y) - s[CC].reorder(k, yz, x) - s[CC].unroll(yz) - s[CC].vectorize(x) - return s - - -def _default_dense_pack_config(cfg, M, N, K): - # Generate default schedule for dynamic shape. - if isinstance(M, (tvm.tir.Var, tvm.tir.Any)): - M = 16 - if isinstance(N, (tvm.tir.Var, tvm.tir.Any)): - N = 16 - if isinstance(K, (tvm.tir.Var, tvm.tir.Any)): - K = 16 - - vec_width = get_simd_32bit_lanes() - tilex_ii = 1 - for bn in range(vec_width * 2, 0, -1): - if N % bn == 0: - tilex_ii = bn - break - NN = N // tilex_ii - tilex_oi = 1 - while NN // tilex_oi > 4: - if (NN // tilex_oi) % 2 == 1: - break - tilex_oi *= 2 - - tiley_ii = 8 - while M % tiley_ii != 0: - tiley_ii //= 2 - MM = M // tiley_ii - tiley_oi = 1 - while MM // tiley_oi > 4: - if (MM // tiley_oi) % 2 == 1: - break - tiley_oi *= 2 - - cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii]) - cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii]) - cfg["tile_k"] = SplitEntity([K, 1]) - cfg["tile_inner"] = SplitEntity([M // tiley_ii, tiley_ii]) - - -def _default_dense_nopack_config(cfg, M, N, K): - # Generate default schedule for dynamic shape. 
- if isinstance(M, (tvm.tir.Var, tvm.tir.Any)): - M = 16 - if isinstance(N, (tvm.tir.Var, tvm.tir.Any)): - N = 16 - if isinstance(K, (tvm.tir.Var, tvm.tir.Any)): - K = 16 - - vec_width = get_simd_32bit_lanes() - tilek_bn = 1 - for bn in range(vec_width * 2, 0, -1): - if K % bn == 0: - tilek_bn = bn - break - cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn]) - cfg["tile_x"] = SplitEntity([N, 1]) - cfg["tile_y"] = SplitEntity([1, M]) - - -@autotvm.register_topi_compute("dense_nopack.x86") -def dense_nopack(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense without packing""" - if out_dtype is None: - out_dtype = data.dtype - M, K = get_const_tuple(data.shape) - N, _ = get_const_tuple(weight.shape) - # create tuning space - cfg.define_split( - "tile_y", 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, num_outputs=2 - ) - cfg.define_split( - "tile_x", 32 if isinstance(N, (tvm.tir.Var, tvm.tir.Any)) else N, num_outputs=2 - ) - cfg.define_split( - "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2 - ) - if cfg.is_fallback: - _default_dense_nopack_config(cfg, M, N, K) - - vec = cfg["tile_k"].size[-1] - k = te.reduce_axis((0, K // vec), "k") - CC = te.compute( - (M, N, vec), - lambda z, y, x: te.sum( - data[z, k * vec + x].astype(out_dtype) * weight[y, k * vec + x].astype(out_dtype), - axis=k, - ), - ) - - kk = te.reduce_axis((0, vec), "kk") - C = te.compute((M, N), lambda y, x: te.sum(CC[y, x, kk], axis=kk), tag="dense_nopack") - if bias is not None: - C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) - return C - - -@autotvm.register_topi_schedule("dense_nopack.x86") -def schedule_dense_nopack(cfg, outs): - """Create the schedule for dense_nopack""" - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if "dense_nopack" in op.tag: - _schedule_dense_nopack_template(cfg, s, op.output(0)) - - traverse_inline(s, outs[0].op, _callback) - return s - - -@autotvm.register_topi_compute("dense_pack.x86") -def dense_pack(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense with transformed weight.""" - if out_dtype is None: - out_dtype = data.dtype - M, K = get_const_tuple(data.shape) # batch, in_dim - if len(weight.shape) == 3: - N, _, packw_bn = get_const_tuple(weight.shape) # out_dim - N = N * packw_bn - else: - N, _ = get_const_tuple(weight.shape) # out_dim - # create tuning space - cfg.define_split( - "tile_y", 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, num_outputs=3 - ) - cfg.define_split( - "tile_x", 32 if isinstance(N, (tvm.tir.Var, tvm.tir.Any)) else N, num_outputs=3 - ) - cfg.define_split( - "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2 - ) - cfg.define_split( - "tile_inner", - 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, - num_outputs=2, - filter=lambda y: y.size[-1] <= 16, - ) - if cfg.is_fallback: - _default_dense_pack_config(cfg, M, N, K) - - if len(weight.shape) == 2: - packw_bn = cfg["tile_x"].size[-1] - packw_shape = (N // packw_bn, K, packw_bn) - if autotvm.GLOBAL_SCOPE.in_tuning: - # Directly use modified data layout placeholder. 
- packw = tvm.te.placeholder(packw_shape, weight.dtype, name="packed_weight") - else: - packw = te.compute( - packw_shape, lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight" - ) - else: - packw = weight - - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda y, x: te.sum( - data[y, k].astype(out_dtype) - * packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype), - axis=k, - ), - tag="dense_pack", - ) - if bias is not None: - C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) - return C - - -@autotvm.register_topi_schedule("dense_pack.x86") -def schedule_dense_pack(cfg, outs): - """Create the schedule for dense_pack""" - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if "dense_pack" in op.tag: - _schedule_dense_pack_template(cfg, s, op.output(0), outs[0]) - - traverse_inline(s, outs[0].op, _callback) - return s - - -@autotvm.register_topi_compute("dense_int8.x86") -def dense_int8(cfg, data, weight, bias=None, out_dtype=None): - """Compute for uint8 x int8 -> int32 dense""" - if out_dtype is None: - out_dtype = data.dtype - assert len(weight.shape) == 4 - assert data.dtype == "uint8" and weight.dtype == "int8" - _, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim - assert n_inner == 16 and k_inner == 4 - return dense_int8_compute(cfg, data, weight, bias) - - -@autotvm.register_topi_schedule("dense_int8.x86") -def schedule_dense_int8(cfg, outs): - """Create a schedule for dense__int8""" - s = te.create_schedule([x.op for x in outs]) - mcpu = tvm.target.Target.current().mcpu - - def _callback(op): - if "dense_int8" in op.tag: - if target_has_amx(mcpu): - dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) - elif target_has_vnni(mcpu): - dense_vnni_schedule(cfg, s, op.output(0), outs[0]) - - traverse_inline(s, outs[0].op, _callback) - return s - - -def dense_int8_compute(cfg, X, packed_w, bias=None): - """Compute for uint8 x int8 -> int32 dense""" - m, k = X.shape - n_o, _, n_i, _ = packed_w.shape - ak = te.reduce_axis((0, k), name="k") - mcpu = tvm.target.Target.current().mcpu - if target_has_vnni(mcpu): - target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"} - else: - target_attr = None - - C = te.compute( - (m, n_o * n_i), - lambda i, j: te.sum( - X[i, ak].astype("int32") - * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype( - "int32" - ), - axis=ak, - ), - tag="dense_int8", - attrs=target_attr, - ) - - if bias is not None: - C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) - - return C - - -def dense_vnni_schedule(cfg, s, C, O, do_parallel=True): - """Schedule dense compute using VNNI vpdpbusd instruction""" - # C: The output of GEMM - # O: The output of the fused op - def split_y(out): - default_y_split_factor = 32 - a_y = out.op.axis[-2] - - if cfg.is_fallback: - return s[out].split(a_y, factor=default_y_split_factor) - - cfg.define_split("tile_y", a_y, num_outputs=2) - return cfg["tile_y"].apply(s, out, a_y) - - (a_k,) = C.op.reduce_axis - - a_yo, a_yi = split_y(C) - a_xo, a_xi = s[C].split(C.op.axis[-1], factor=16) - a_ko, a_ki = s[C].split(a_k, factor=4) - - s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki) - - pc = dot_16x1x16_uint8_int8_int32_cascadelake() - s[C].tensorize(a_xi, pc) - - if C == O: - fused = s[O].fuse(a_yo, a_xo) - else: - a_yo, a_yi = split_y(O) - a_xo, a_xi = s[O].split(O.op.axis[-1], factor=16) - - 
s[O].reorder(a_yo, a_xo, a_yi, a_xi) - s[O].vectorize(a_xi) - s[C].compute_at(s[O], a_yi) - - fused = s[O].fuse(a_yo, a_xo) - - if do_parallel: - s[O].parallel(fused) - - return s, fused - - -def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True): - """Schedule dense compute using AMX TMUL instruction""" - # C: The output of GEMM - # O: The output of the fused op - def split_x(out): - default_x_split_factor1 = 32 - default_x_split_factor2 = 2 - default_x_split_factor3 = 2 - default_x_split_factor4 = 2 - a_x = s[out].op.axis[-2] - - if cfg.is_fallback: - a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1) - a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2) - a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3) - a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4) - return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi] - - cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x: x.size[-1] == 32) - return cfg["tile_x"].apply(s, out, a_x) - - def split_y(out): - default_y_split_factor1 = 32 - default_y_split_factor2 = 4 - default_y_split_factor3 = 4 - default_y_split_factor4 = 4 - a_y = s[out].op.axis[-1] - - if cfg.is_fallback: - a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1) - a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2) - a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3) - a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4) - return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo] - - cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y: y.size[-1] == 32) - return cfg["tile_y"].apply(s, out, a_y) - - def split_k(out, rd_axis): - default_k_split_factor1 = 128 - default_k_split_factor2 = 2 - default_k_split_factor3 = 2 - default_k_split_factor4 = 2 - - if cfg.is_fallback: - a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1) - a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2) - a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3) - a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4) - return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki] - - cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128) - return cfg["tile_k"].apply(s, out, rd_axis) - - a_x, a_y = C.op.axis - (a_k,) = C.op.reduce_axis - CF = s.cache_write(C, "amx.tmm") - - a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C) - a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C) - s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi) - - s[CF].compute_at(s[C], a_yo) - - (a_k_f,) = CF.op.reduce_axis - a_x_f, a_y_f = CF.op.axis - - a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32) - - a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32) - a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f) - s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f) - - (m, k) = CF.op.input_tensors[0].shape - (n, c, n_i, c_i) = CF.op.input_tensors[1].shape - n = n * n_i - - s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k))) - s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n))) - - if C == O: - fused = s[O].fuse(a_x3, a_y3) - else: - a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O) - a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O) - - s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, a_xi) - s[O].vectorize(a_xi) - - fused = s[O].fuse(a_x3, a_y3) - - if do_parallel: - s[O].parallel(fused) - - return s, fused - - -@autotvm.register_topi_schedule("dense_amx_int8.x86") -def 
schedule_dense_amx_int8(cfg, outs): - """Create a schedule for dense_amx_int8""" - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if "dense_amx_int8" in op.tag: - dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) - - traverse_inline(s, outs[0].op, _callback) - return s - - -def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib): - """Compute matmul/dense using a BLAS library""" - M, K = get_const_tuple(tensor_a.shape) - N, _ = get_const_tuple(tensor_b.shape) - if isinstance(M, int) and isinstance(K, int) and isinstance(N, int): - cfg.add_flop(M * K * N * 2) - if tensor_a.dtype == "uint8" and tensor_b.dtype == "int8" and out_dtype == "int32": - if not hasattr(lib, "matmul_u8s8s32"): - raise NotImplementedError( - f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not supported " - "(matmulu8s8s32 not imlemented)" - ) - C = lib.matmul_u8s8s32(tensor_a, tensor_b, transpose_a, transpose_b, dtype=out_dtype) - elif tensor_a.dtype == "float32" or tensor_a.dtype == "float64": - C = lib.matmul(tensor_a, tensor_b, transpose_a, transpose_b) - else: - raise NotImplementedError( - f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not supported" - ) - - if bias is not None: - C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) - return C - - -@autotvm.register_topi_compute("dense_cblas.x86") -def dense_cblas(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense using cblas. This is an alias of matmul_nt operator.""" - return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, cblas) - - -@autotvm.register_topi_schedule("dense_cblas.x86") -def schedule_dense_cblas(_, outs): - """Create schedule for dense_cblas. This is an alias of matmul_nt operator.""" - return generic.schedule_extern(outs) - - -@autotvm.register_topi_compute("dense_mkl.x86") -def dense_mkl(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense using mkl. This is an alias of matmul_nt operator.""" - return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkl) - - -@autotvm.register_topi_schedule("dense_mkl.x86") -def schedule_dense_mkl(_, outs): - """Create schedule for dense_mkl. This is an alias of matmul_nt operator.""" - return generic.schedule_extern(outs) - - -@autotvm.register_topi_compute("dense_dnnl.x86") -def dense_dnnl(cfg, data, weight, bias=None, out_dtype=None): - """Compute dense using dnnl. This is an alias of matmul_nt operator.""" - return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, dnnl) - - -@autotvm.register_topi_schedule("dense_dnnl.x86") -def schedule_dense_dnnl(_, outs): - """Create schedule for dense_dnnl. 
This is an alias of matmul_nt operator.""" - return generic.schedule_extern(outs) - - -@autotvm.register_topi_compute("matmul_cblas.x86") -def matmul_cblas( - cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False -): - """Compute matmul using cblas.""" - return matmul_blas_common( - cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, cblas - ) - - -@autotvm.register_topi_schedule("matmul_cblas.x86") -def schedule_matmul_cblas(_, outs): - """Create schedule for matmul_cblas.""" - return generic.schedule_extern(outs) - - -@autotvm.register_topi_compute("matmul_mkl.x86") -def matmul_mkl( - cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False -): - """Compute matmul using mkl.""" - return matmul_blas_common( - cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, mkl - ) - - -@autotvm.register_topi_schedule("matmul_mkl.x86") -def schedule_matmul_mkl(_, outs): - """Create schedule for matmul_mkl.""" - return generic.schedule_extern(outs) - - -@autotvm.register_topi_compute("matmul_dnnl.x86") -def matmul_dnnl( - cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False -): - """Compute matmul using dnnl.""" - return matmul_blas_common( - cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, dnnl - ) - - -@autotvm.register_topi_schedule("matmul_dnnl.x86") -def schedule_matmul_dnnl(_, outs): - """Create schedule for matmul_dnnl.""" - return generic.schedule_extern(outs) - - -def dense_dynamic(A, B, bias, dtype): - """Compute for dense with dynamic shape""" - - assert A.shape[0] == 1, "Only dynamic matrix vector multiplication with vector LHS is supported" - - # Right now we only support matrix-vector multiplication with lhs as the - # vector. We don't need to do much optimization here because the access - # pattern and parallelization are straight forward. 
- def gen_ir(a, b, c): - ib = tvm.tir.ir_builder.create() - A = ib.buffer_ptr(a) - B = ib.buffer_ptr(b) - C = ib.buffer_ptr(c) - with ib.for_range(0, b.shape[0], name="j", kind="parallel") as j: - C[0, j] = 0.0 - with ib.for_range(0, b.shape[1], name="k") as k: - C[0, j] += A[0, k] * B[j, k] - return ib.get() - - def gen_ir_bias(a, b, bias, c): - ib = tvm.tir.ir_builder.create() - A = ib.buffer_ptr(a) - B = ib.buffer_ptr(b) - C = ib.buffer_ptr(c) - with ib.for_range(0, b.shape[0], name="j", kind="parallel") as j: - C[0, j] = bias[j] - with ib.for_range(0, b.shape[1], name="k") as k: - C[0, j] += A[0, k] * B[j, k] - return ib.get() - - out_shape = (A.shape[0], B.shape[0]) - out_buf = tvm.tir.decl_buffer(out_shape, dtype, "out_buf") - if bias is None: - out = te.extern( - [out_shape], - [A, B], - lambda ins, outs: gen_ir(*ins, *outs), - dtype=dtype, - out_buffers=[out_buf], - name="dense_dynamic_cpu", - tag="dense_dynamic_cpu", - ) - else: - out = te.extern( - [out_shape], - [A, B, bias], - lambda ins, outs: gen_ir_bias(*ins, *outs), - dtype=dtype, - out_buffers=[out_buf], - name="dense_dynamic_cpu", - tag="dense_dynamic_cpu", - ) - return out - - -def schedule_dense_dynamic(outs): - """Create schedule for dense_dynamic.""" - return generic.schedule_extern(outs) From 1f59aff6421bfca7289f0c052923c576b4944407 Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Thu, 5 Jan 2023 10:31:08 +0000 Subject: [PATCH 19/21] add vnni ms target attributes and remove misops --- python/tvm/topi/x86/dense.py | 661 +++++++++++++++++++++++++++++++++++ 1 file changed, 661 insertions(+) create mode 100644 python/tvm/topi/x86/dense.py diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py new file mode 100644 index 000000000000..ada19d598cdf --- /dev/null +++ b/python/tvm/topi/x86/dense.py @@ -0,0 +1,661 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,too-many-locals,unused-argument +# pylint: disable=no-value-for-parameter,unused-variable +"""x86 dense operators""" +from __future__ import absolute_import as _abs + +import tvm +from tvm import autotvm, te +from tvm.autotvm.task.space import SplitEntity +from tvm.contrib import cblas, dnnl, mkl + +from .. 
import generic, tag +from ..utils import get_const_tuple, traverse_inline +from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake +from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids +from .tensor_intrin import acc_32x32_int32_sapphirerapids +from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx + + +def _schedule_dense_pack_template(cfg, s, C, O): + A, packedB = s[C].op.input_tensors + + CC = s.cache_write(C, "global") + y, x = s[C].op.axis + (k,) = s[CC].op.reduce_axis + + yt, yo, yi = cfg["tile_y"].apply(s, C, y) + xt, xo, xi = cfg["tile_x"].apply(s, C, x) + s[C].reorder(xt, yt, yo, xo, yi, xi) + xyt = s[C].fuse(xt, yt) + if C == O: + s[C].parallel(xyt) + xyo = s[C].fuse(yo, xo) + s[C].unroll(yi) + s[C].vectorize(xi) + + s[CC].compute_at(s[C], xyo) + y, x = s[CC].op.axis + ko, ki = cfg["tile_k"].apply(s, CC, k) + s[CC].reorder(ko, ki, y, x) + s[CC].vectorize(x) + + tile_inner = cfg["tile_inner"].size[-1] + if tile_inner > 1: + yo, yi = s[CC].split(y, tile_inner) + s[CC].reorder(ko, yo, ki, yi, x) + s[CC].unroll(yo) + s[CC].unroll(ki) + s[CC].unroll(yi) + else: + s[CC].unroll(ki) + s[CC].unroll(y) + + if C != O: + y, x = s[O].op.axis + yt, yo, yi = cfg["tile_y"].apply(s, O, y) + xt, xo, xi = cfg["tile_x"].apply(s, O, x) + s[O].reorder(xt, yt, yo, xo, yi, xi) + xyt = s[O].fuse(xt, yt) + s[C].compute_at(s[O], xyt) + s[O].vectorize(xi) + s[O].parallel(xyt) + return s + + +def _schedule_dense_nopack_template(cfg, s, C): + y, x = s[C].op.axis + (kk,) = s[C].op.reduce_axis + yo, yi = cfg["tile_y"].apply(s, C, y) + xo, xi = cfg["tile_x"].apply(s, C, x) + s[C].reorder(yo, xo, yi, xi) + xyo = s[C].fuse(yo, xo) + s[C].parallel(xyo) + s[C].unroll(kk) + + (CC,) = s[C].op.input_tensors + s[CC].compute_at(s[C], xyo) + z, y, x = s[CC].op.axis + (k,) = s[CC].op.reduce_axis + yz = s[CC].fuse(z, y) + s[CC].reorder(k, yz, x) + s[CC].unroll(yz) + s[CC].vectorize(x) + return s + + +def _default_dense_pack_config(cfg, M, N, K): + # Generate default schedule for dynamic shape. + if isinstance(M, (tvm.tir.Var, tvm.tir.Any)): + M = 16 + if isinstance(N, (tvm.tir.Var, tvm.tir.Any)): + N = 16 + if isinstance(K, (tvm.tir.Var, tvm.tir.Any)): + K = 16 + + vec_width = get_simd_32bit_lanes() + tilex_ii = 1 + for bn in range(vec_width * 2, 0, -1): + if N % bn == 0: + tilex_ii = bn + break + NN = N // tilex_ii + tilex_oi = 1 + while NN // tilex_oi > 4: + if (NN // tilex_oi) % 2 == 1: + break + tilex_oi *= 2 + + tiley_ii = 8 + while M % tiley_ii != 0: + tiley_ii //= 2 + MM = M // tiley_ii + tiley_oi = 1 + while MM // tiley_oi > 4: + if (MM // tiley_oi) % 2 == 1: + break + tiley_oi *= 2 + + cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii]) + cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii]) + cfg["tile_k"] = SplitEntity([K, 1]) + cfg["tile_inner"] = SplitEntity([M // tiley_ii, tiley_ii]) + + +def _default_dense_nopack_config(cfg, M, N, K): + # Generate default schedule for dynamic shape. 
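+ # A symbolic dimension (tvm.tir.Var / tvm.tir.Any) has no usable extent at schedule time, so each such M, N or K is pinned to a representative size of 16 below.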
+ if isinstance(M, (tvm.tir.Var, tvm.tir.Any)): + M = 16 + if isinstance(N, (tvm.tir.Var, tvm.tir.Any)): + N = 16 + if isinstance(K, (tvm.tir.Var, tvm.tir.Any)): + K = 16 + + vec_width = get_simd_32bit_lanes() + tilek_bn = 1 + for bn in range(vec_width * 2, 0, -1): + if K % bn == 0: + tilek_bn = bn + break + cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn]) + cfg["tile_x"] = SplitEntity([N, 1]) + cfg["tile_y"] = SplitEntity([1, M]) + + +@autotvm.register_topi_compute("dense_nopack.x86") +def dense_nopack(cfg, data, weight, bias=None, out_dtype=None): + """Compute dense without packing""" + if out_dtype is None: + out_dtype = data.dtype + M, K = get_const_tuple(data.shape) + N, _ = get_const_tuple(weight.shape) + # create tuning space + cfg.define_split( + "tile_y", 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, num_outputs=2 + ) + cfg.define_split( + "tile_x", 32 if isinstance(N, (tvm.tir.Var, tvm.tir.Any)) else N, num_outputs=2 + ) + cfg.define_split( + "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2 + ) + if cfg.is_fallback: + _default_dense_nopack_config(cfg, M, N, K) + + vec = cfg["tile_k"].size[-1] + k = te.reduce_axis((0, K // vec), "k") + CC = te.compute( + (M, N, vec), + lambda z, y, x: te.sum( + data[z, k * vec + x].astype(out_dtype) * weight[y, k * vec + x].astype(out_dtype), + axis=k, + ), + ) + + kk = te.reduce_axis((0, vec), "kk") + C = te.compute((M, N), lambda y, x: te.sum(CC[y, x, kk], axis=kk), tag="dense_nopack") + if bias is not None: + C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) + return C + + +@autotvm.register_topi_schedule("dense_nopack.x86") +def schedule_dense_nopack(cfg, outs): + """Create the schedule for dense_nopack""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "dense_nopack" in op.tag: + _schedule_dense_nopack_template(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("dense_pack.x86") +def dense_pack(cfg, data, weight, bias=None, out_dtype=None): + """Compute dense with transformed weight.""" + if out_dtype is None: + out_dtype = data.dtype + M, K = get_const_tuple(data.shape) # batch, in_dim + if len(weight.shape) == 3: + N, _, packw_bn = get_const_tuple(weight.shape) # out_dim + N = N * packw_bn + else: + N, _ = get_const_tuple(weight.shape) # out_dim + # create tuning space + cfg.define_split( + "tile_y", 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, num_outputs=3 + ) + cfg.define_split( + "tile_x", 32 if isinstance(N, (tvm.tir.Var, tvm.tir.Any)) else N, num_outputs=3 + ) + cfg.define_split( + "tile_k", 32 if isinstance(K, (tvm.tir.Var, tvm.tir.Any)) else K, num_outputs=2 + ) + cfg.define_split( + "tile_inner", + 32 if isinstance(M, (tvm.tir.Var, tvm.tir.Any)) else M, + num_outputs=2, + filter=lambda y: y.size[-1] <= 16, + ) + if cfg.is_fallback: + _default_dense_pack_config(cfg, M, N, K) + + if len(weight.shape) == 2: + packw_bn = cfg["tile_x"].size[-1] + packw_shape = (N // packw_bn, K, packw_bn) + if autotvm.GLOBAL_SCOPE.in_tuning: + # Directly use modified data layout placeholder. 
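+ # The placeholder has the packed shape (N // packw_bn, K, packw_bn), so tuning skips the cost of the pack step; outside tuning, packed_weight[z, y, x] == weight[z * packw_bn + x, y].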
+ packw = tvm.te.placeholder(packw_shape, weight.dtype, name="packed_weight") + else: + packw = te.compute( + packw_shape, lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight" + ) + else: + packw = weight + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda y, x: te.sum( + data[y, k].astype(out_dtype) + * packw[idxdiv(x, packw_bn), k, idxmod(x, packw_bn)].astype(out_dtype), + axis=k, + ), + tag="dense_pack", + ) + if bias is not None: + C = te.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) + return C + + +@autotvm.register_topi_schedule("dense_pack.x86") +def schedule_dense_pack(cfg, outs): + """Create the schedule for dense_pack""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "dense_pack" in op.tag: + _schedule_dense_pack_template(cfg, s, op.output(0), outs[0]) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("dense_int8.x86") +def dense_int8(cfg, data, weight, bias=None, out_dtype=None): + """Compute for uint8 x int8 -> int32 dense""" + if out_dtype is None: + out_dtype = data.dtype + assert len(weight.shape) == 4 + assert data.dtype == "uint8" and weight.dtype == "int8" + _, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim + assert n_inner == 16 and k_inner == 4 + return dense_int8_compute(cfg, data, weight, bias) + + +@autotvm.register_topi_schedule("dense_int8.x86") +def schedule_dense_int8(cfg, outs): + """Create a schedule for dense_int8""" + s = te.create_schedule([x.op for x in outs]) + mcpu = tvm.target.Target.current().mcpu + + def _callback(op): + if "dense_int8" in op.tag: + if target_has_amx(mcpu): + dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) + elif target_has_vnni(mcpu): + dense_vnni_schedule(cfg, s, op.output(0), outs[0]) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def dense_int8_compute(cfg, X, packed_w, bias=None): + """Compute for uint8 x int8 -> int32 dense""" + m, k = X.shape + n_o, _, n_i, _ = packed_w.shape + ak = te.reduce_axis((0, k), name="k") + mcpu = tvm.target.Target.current().mcpu + if target_has_vnni(mcpu): + target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"} + else: + target_attr = None + + C = te.compute( + (m, n_o * n_i), + lambda i, j: te.sum( + X[i, ak].astype("int32") + * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype( + "int32" + ), + axis=ak, + ), + tag="dense_int8", + attrs=target_attr, + ) + + if bias is not None: + C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) + + return C + + +def dense_vnni_schedule(cfg, s, C, O, do_parallel=True): + """Schedule dense compute using VNNI vpdpbusd instruction""" + # C: The output of GEMM + # O: The output of the fused op + def split_y(out): + default_y_split_factor = 32 + a_y = out.op.axis[-2] + + if cfg.is_fallback: + return s[out].split(a_y, factor=default_y_split_factor) + + cfg.define_split("tile_y", a_y, num_outputs=2) + return cfg["tile_y"].apply(s, out, a_y) + + (a_k,) = C.op.reduce_axis + + a_yo, a_yi = split_y(C) + a_xo, a_xi = s[C].split(C.op.axis[-1], factor=16) + a_ko, a_ki = s[C].split(a_k, factor=4) + + s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki) + + pc = dot_16x1x16_uint8_int8_int32_cascadelake() + s[C].tensorize(a_xi, pc) + + if C == O: + fused = s[O].fuse(a_yo, a_xo) + else: + a_yo, a_yi = split_y(O) + a_xo, a_xi = s[O].split(O.op.axis[-1], factor=16) + + 
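# The epilogue O is split with the same factor-16 inner x axis as the tensorized GEMM, so the compute_at below nests C inside compatible loops of O. +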
s[O].reorder(a_yo, a_xo, a_yi, a_xi) + s[O].vectorize(a_xi) + s[C].compute_at(s[O], a_yi) + + fused = s[O].fuse(a_yo, a_xo) + + if do_parallel: + s[O].parallel(fused) + + return s, fused + + +def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True): + """Schedule dense compute using AMX TMUL instruction""" + # C: The output of GEMM + # O: The output of the fused op + def split_x(out): + default_x_split_factor1 = 32 + default_x_split_factor2 = 2 + default_x_split_factor3 = 2 + default_x_split_factor4 = 2 + a_x = s[out].op.axis[-2] + + if cfg.is_fallback: + a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1) + a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2) + a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3) + a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4) + return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi] + + cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x: x.size[-1] == 32) + return cfg["tile_x"].apply(s, out, a_x) + + def split_y(out): + default_y_split_factor1 = 32 + default_y_split_factor2 = 4 + default_y_split_factor3 = 4 + default_y_split_factor4 = 4 + a_y = s[out].op.axis[-1] + + if cfg.is_fallback: + a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1) + a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2) + a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3) + a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4) + return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo] + + cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y: y.size[-1] == 32) + return cfg["tile_y"].apply(s, out, a_y) + + def split_k(out, rd_axis): + default_k_split_factor1 = 128 + default_k_split_factor2 = 2 + default_k_split_factor3 = 2 + default_k_split_factor4 = 2 + + if cfg.is_fallback: + a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1) + a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2) + a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3) + a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4) + return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki] + + cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128) + return cfg["tile_k"].apply(s, out, rd_axis) + + a_x, a_y = C.op.axis + (a_k,) = C.op.reduce_axis + CF = s.cache_write(C, "amx.tmm") + + a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C) + a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C) + s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi) + + s[CF].compute_at(s[C], a_yo) + + (a_k_f,) = CF.op.reduce_axis + a_x_f, a_y_f = CF.op.axis + + a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32) + + a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32) + a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f) + s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f) + + (m, k) = CF.op.input_tensors[0].shape + (n, c, n_i, c_i) = CF.op.input_tensors[1].shape + n = n * n_i + + s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k))) + s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n))) + + if C == O: + fused = s[O].fuse(a_x3, a_y3) + else: + a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O) + a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O) + + s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, a_xi) + s[O].vectorize(a_xi) + + fused = s[O].fuse(a_x3, a_y3) + + if do_parallel: + s[O].parallel(fused) + + return s, fused + + +@autotvm.register_topi_schedule("dense_amx_int8.x86") +def 
schedule_dense_amx_int8(cfg, outs): + """Create a schedule for dense_amx_int8""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "dense_amx_int8" in op.tag: + dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib): + """Compute matmul/dense using a BLAS library""" + M, K = get_const_tuple(tensor_a.shape) + N, _ = get_const_tuple(tensor_b.shape) + if isinstance(M, int) and isinstance(K, int) and isinstance(N, int): + cfg.add_flop(M * K * N * 2) + if tensor_a.dtype == "uint8" and tensor_b.dtype == "int8" and out_dtype == "int32": + if not hasattr(lib, "matmul_u8s8s32"): + raise NotImplementedError( + f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not supported " + "(matmul_u8s8s32 not implemented)" + ) + C = lib.matmul_u8s8s32(tensor_a, tensor_b, transpose_a, transpose_b, dtype=out_dtype) + elif tensor_a.dtype == "float32" or tensor_a.dtype == "float64": + C = lib.matmul(tensor_a, tensor_b, transpose_a, transpose_b) + else: + raise NotImplementedError( + f"Matmul/Dense with {lib.__name__} for {tensor_a.dtype} is not supported" + ) + + if bias is not None: + C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST) + return C + + +@autotvm.register_topi_compute("dense_cblas.x86") +def dense_cblas(cfg, data, weight, bias=None, out_dtype=None): + """Compute dense using cblas. This is an alias of matmul_nt operator.""" + return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, cblas) + + +@autotvm.register_topi_schedule("dense_cblas.x86") +def schedule_dense_cblas(_, outs): + """Create schedule for dense_cblas. This is an alias of matmul_nt operator.""" + return generic.schedule_extern(outs) + + +@autotvm.register_topi_compute("dense_mkl.x86") +def dense_mkl(cfg, data, weight, bias=None, out_dtype=None): + """Compute dense using mkl. This is an alias of matmul_nt operator.""" + return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkl) + + +@autotvm.register_topi_schedule("dense_mkl.x86") +def schedule_dense_mkl(_, outs): + """Create schedule for dense_mkl. This is an alias of matmul_nt operator.""" + return generic.schedule_extern(outs) + + +@autotvm.register_topi_compute("dense_dnnl.x86") +def dense_dnnl(cfg, data, weight, bias=None, out_dtype=None): + """Compute dense using dnnl. This is an alias of matmul_nt operator.""" + return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, dnnl) + + +@autotvm.register_topi_schedule("dense_dnnl.x86") +def schedule_dense_dnnl(_, outs): + """Create schedule for dense_dnnl. 
This is an alias of matmul_nt operator.""" + return generic.schedule_extern(outs) + + +@autotvm.register_topi_compute("matmul_cblas.x86") +def matmul_cblas( + cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False +): + """Compute matmul using cblas.""" + return matmul_blas_common( + cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, cblas + ) + + +@autotvm.register_topi_schedule("matmul_cblas.x86") +def schedule_matmul_cblas(_, outs): + """Create schedule for matmul_cblas.""" + return generic.schedule_extern(outs) + + +@autotvm.register_topi_compute("matmul_mkl.x86") +def matmul_mkl( + cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False +): + """Compute matmul using mkl.""" + return matmul_blas_common( + cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, mkl + ) + + +@autotvm.register_topi_schedule("matmul_mkl.x86") +def schedule_matmul_mkl(_, outs): + """Create schedule for matmul_mkl.""" + return generic.schedule_extern(outs) + + +@autotvm.register_topi_compute("matmul_dnnl.x86") +def matmul_dnnl( + cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False, transpose_b=False +): + """Compute matmul using dnnl.""" + return matmul_blas_common( + cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, dnnl + ) + + +@autotvm.register_topi_schedule("matmul_dnnl.x86") +def schedule_matmul_dnnl(_, outs): + """Create schedule for matmul_dnnl.""" + return generic.schedule_extern(outs) + + +def dense_dynamic(A, B, bias, dtype): + """Compute for dense with dynamic shape""" + + assert A.shape[0] == 1, "Only dynamic matrix vector multiplication with vector LHS is supported" + + # Right now we only support matrix-vector multiplication with lhs as the + # vector. We don't need to do much optimization here because the access + # pattern and parallelization are straightforward. + def gen_ir(a, b, c): + ib = tvm.tir.ir_builder.create() + A = ib.buffer_ptr(a) + B = ib.buffer_ptr(b) + C = ib.buffer_ptr(c) + with ib.for_range(0, b.shape[0], name="j", kind="parallel") as j: + C[0, j] = 0.0 + with ib.for_range(0, b.shape[1], name="k") as k: + C[0, j] += A[0, k] * B[j, k] + return ib.get() + + def gen_ir_bias(a, b, bias, c): + ib = tvm.tir.ir_builder.create() + A = ib.buffer_ptr(a) + B = ib.buffer_ptr(b) + C = ib.buffer_ptr(c) + with ib.for_range(0, b.shape[0], name="j", kind="parallel") as j: + C[0, j] = bias[j] + with ib.for_range(0, b.shape[1], name="k") as k: + C[0, j] += A[0, k] * B[j, k] + return ib.get() + + out_shape = (A.shape[0], B.shape[0]) + out_buf = tvm.tir.decl_buffer(out_shape, dtype, "out_buf") + if bias is None: + out = te.extern( + [out_shape], + [A, B], + lambda ins, outs: gen_ir(*ins, *outs), + dtype=dtype, + out_buffers=[out_buf], + name="dense_dynamic_cpu", + tag="dense_dynamic_cpu", + ) + else: + out = te.extern( + [out_shape], + [A, B, bias], + lambda ins, outs: gen_ir_bias(*ins, *outs), + dtype=dtype, + out_buffers=[out_buf], + name="dense_dynamic_cpu", + tag="dense_dynamic_cpu", + ) + return out + + +def schedule_dense_dynamic(outs): + """Create schedule for dense_dynamic.""" + return generic.schedule_extern(outs) From 383d0b2899c3e7e9399f5b4a01a61a6fb6f4796d Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Thu, 5 Jan 2023 10:58:01 +0000 Subject: [PATCH 20/21] Revert "add vnni ms target attributes" This reverts commit c2e9f26fd9d84ce75e9a0c1474df1b7e0b9ff4f3. 
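The revert restores the unconditional attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"} on the dense_int8 compute; the target-gated variant is re-added in the next patch.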
--- .../zephyr/template_project/qemu-hack/qemu-system-arm | 1 + .../zephyr/template_project/qemu-hack/qemu-system-riscv32 | 1 + .../zephyr/template_project/qemu-hack/qemu-system-riscv64 | 1 + .../template_project/qemu-hack/qemu-system-xilinx-aarch64 | 1 + apps/sgx/.rustfmt.toml | 1 + python/tvm/topi/x86/dense.py | 7 +------ 6 files changed, 6 insertions(+), 6 deletions(-) create mode 120000 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm create mode 120000 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 create mode 120000 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 create mode 120000 apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 create mode 120000 apps/sgx/.rustfmt.toml diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm new file mode 120000 index 000000000000..ebbc8ad5ad9d --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm @@ -0,0 +1 @@ +qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 new file mode 120000 index 000000000000..ebbc8ad5ad9d --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 @@ -0,0 +1 @@ +qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 new file mode 120000 index 000000000000..ebbc8ad5ad9d --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64 @@ -0,0 +1 @@ +qemu-system-i386 \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 new file mode 120000 index 000000000000..ebbc8ad5ad9d --- /dev/null +++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64 @@ -0,0 +1 @@ +qemu-system-i386 \ No newline at end of file diff --git a/apps/sgx/.rustfmt.toml b/apps/sgx/.rustfmt.toml new file mode 120000 index 000000000000..27139e42a3f2 --- /dev/null +++ b/apps/sgx/.rustfmt.toml @@ -0,0 +1 @@ +../../rust/.rustfmt.toml \ No newline at end of file diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index ada19d598cdf..5f5990b32b26 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -314,11 +314,6 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): m, k = X.shape n_o, _, n_i, _ = packed_w.shape ak = te.reduce_axis((0, k), name="k") - mcpu = tvm.target.Target.current().mcpu - if target_has_vnni(mcpu): - target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"} - else: - target_attr = None C = te.compute( (m, n_o * n_i), @@ -330,7 +325,7 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): axis=ak, ), tag="dense_int8", - attrs=target_attr, + attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"}, ) if bias is not None: From 94223634b29f66357753e7a3cd7eebe68820c6c1 Mon Sep 17 00:00:00 2001 From: "Jiang, Qianshui" Date: Thu, 5 Jan 2023 11:00:22 +0000 Subject: [PATCH 21/21] remove the misops --- python/tvm/topi/x86/dense.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 5f5990b32b26..ada19d598cdf 100644 --- 
a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -314,6 +314,11 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): m, k = X.shape n_o, _, n_i, _ = packed_w.shape ak = te.reduce_axis((0, k), name="k") + mcpu = tvm.target.Target.current().mcpu + if target_has_vnni(mcpu): + target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"} + else: + target_attr = None C = te.compute( (m, n_o * n_i), @@ -325,7 +330,7 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): axis=ak, ), tag="dense_int8", - attrs={"schedule_rule": "meta_schedule.x86.dense_vnni"}, + attrs=target_attr, ) if bias is not None:
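
Note on the blocked int8 weight layout used by dense_int8 above: the compute indexes packed_w[j // 16, k // 4, j % 16, k % 4], i.e. a (N // 16, K // 4, 16, 4) layout matching the n_inner == 16 and k_inner == 4 asserts. A minimal NumPy sketch of that packing relation and of the int32 reference result follows; the sizes and array names here are illustrative assumptions, not taken from the patches:

import numpy as np

N, K, M = 64, 128, 8  # example sizes; N a multiple of 16, K a multiple of 4
x = np.random.randint(0, 16, size=(M, K)).astype("uint8")
w = np.random.randint(-8, 8, size=(N, K)).astype("int8")

# (N, K) -> (N // 16, K // 4, 16, 4): packed[j // 16, k // 4, j % 16, k % 4] == w[j, k]
packed = w.reshape(N // 16, 16, K // 4, 4).transpose(0, 2, 1, 3)
j, k = 37, 101
assert packed[j // 16, k // 4, j % 16, k % 4] == w[j, k]

# uint8 x int8 -> int32 result that the dense_int8 compute produces (bias-free case)
c = x.astype("int32") @ w.astype("int32").T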