diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index c3e43f897c509..473ec51524f22 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -33,7 +33,6 @@ constexpr size_t A = 0, }; typedef enum { - Level0, /*!< input fp32, accumulator fp32 */ Level1, /*!< input fp32, accumulator fp32 */ Level2, /*!< input fp16, accumulator fp16 */ Level3, /*!< input bf16, accumulator fp32 */ diff --git a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp index d7d99899d544a..f1bc013a469d9 100644 --- a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -23,7 +23,7 @@ Module Name: #include #include "fp16_common.h" -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon @@ -131,7 +131,7 @@ HQ4BitGemmPackQuantBData_CompFp16( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 100d7d47751aa..56c3b7b541037 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1049,13 +1049,6 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; -// -// Rotary embedding dispatch structure. -// -struct MLAS_ROPE_DISPATCH; -extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; - - // // Quantized depthwise convolution kernels. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index ec572a4150292..06a96df571fcc 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -573,6 +573,9 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; + + // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } #if defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index f064a8e1d6a78..3ffa8238fb9ff 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -402,7 +402,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( BlkLen, dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks ); @@ -808,7 +808,7 @@ GetQNBitGemm(QNBitGemmVariant variant) { switch (variant) { case HQNBitGemmVariant_BitWidth4_CompFp16: - return HQ4BitGemm_CompFp16; + return nullptr; default: return nullptr; } diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index eb3d0b44ae3de..5fc93e0ab52ef 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -91,17 +91,17 @@ struct MLAS_QNBIT_GEMM_DISPATCH { // /** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ - typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)( + typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( size_t N, size_t K, size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; + SQ4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */ - typedef void(Q4BitGemmPackQuantBData_Fn)( + typedef void(SQ4BitGemmPackQuantBData_Fn)( size_t N, size_t K, size_t BlkLen, @@ -151,7 +151,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; + SQ4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -164,7 +164,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH { MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; + SQ4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; // // SQNBIT_CompFp32 kernel function prototypes. diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index d05de64e68ec8..b3dc0f44f262f 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -20,7 +20,7 @@ Module Name: #include #include "qnbitgemm.h" -#include "qnbitgemm_kernel_neon.h" +#include "sqnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h index ccadd24ac1991..ef2909452faa8 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h @@ -64,47 +64,6 @@ SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlockCountK ); -// HQNBIT_CompFp16 declarations -#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) -void -HQ4BitGemmPackQuantBData_CompFp16( - size_t N, - size_t K, - size_t BlkLen, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -); - -void -HQ4BitBlkDequantBForHgemm_CompFp16( - size_t BlkLen, - MLAS_FP16* FpData, - const std::byte* QuantBData, - const MLAS_FP16* QuantBScale, - const std::byte* QuantBZeroPoint, - size_t CountN, - size_t K, - size_t BlockCountK -); - -void -HQ4BitGemmKernel_CompFp16( - const MLAS_FP16* A, - const MLAS_FP16* B, - const MLAS_FP16* Bias, - MLAS_FP16* C, - size_t CountM, - size_t CountN, - size_t K, - size_t lda, - size_t ldb, - size_t ldc -); - -#endif // !(defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)) - // SQNBIT_CompInt8 declarations void diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp index 31a499b8243af..cb1ca4a82d91a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -22,7 +22,7 @@ Module Name: #include #include "qnbitgemm.h" -#include "qnbitgemm_kernel_neon.h" +#include "sqnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index 73beb06a3cfad..f1acd99c7b693 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -22,7 +22,7 @@ Module Name: #include #include "qnbitgemm.h" -#include "qnbitgemm_kernel_neon.h" +#include "sqnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon diff --git a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch index 95a4e4650e9fe..d66d4952d4544 100644 --- a/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch +++ b/onnxruntime/core/providers/vsinpu/patches/mlas_crosscompiling.patch @@ -15,7 +15,7 @@ diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 28ae64c4d5..0c77e0ca78 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h -@@ -83,6 +83,9 @@ Abstract: +@@ -82,6 +82,9 @@ Abstract: #if (!defined(_MSC_VER)) || (_MSC_VER >= 1930) #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) @@ -25,7 +25,7 @@ index 28ae64c4d5..0c77e0ca78 100644 #if !defined(__APPLE__) // Had to temporary disable fp16 under APPLE ARM64, as compiling // the source files require a hardware specific compilation flag. -@@ -91,6 +94,7 @@ Abstract: +@@ -90,6 +93,7 @@ Abstract: #define MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -95,8 +95,17 @@ diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/ index b3c9461293..424c3b0441 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp -@@ -574,7 +574,7 @@ Return Value: - this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; +@@ -20,7 +20,7 @@ Abstract: + #include + #include + +-#if defined(MLAS_TARGET_POWER) ++#if defined(MLAS_TARGET_POWER) + #if defined(__linux__) + #include + #elif defined(_AIX) +@@ -536,7 +536,7 @@ Return Value: + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } -#if defined(__linux__) @@ -126,7 +135,7 @@ diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/pr index 2c6d23e4de..61aaacdfd6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc -@@ -133,7 +133,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { +@@ -132,7 +132,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { return Status::OK(); } diff --git a/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp index 64d229889214b..0cbfc49a4e34c 100644 --- a/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp @@ -17,16 +17,16 @@ #include "core/util/thread_utils.h" #include "core/platform/env_var_utils.h" -template -void RunQNBitGemmBenchmark(size_t BlkLen, - size_t M, size_t N, size_t K, - size_t Threads, - bool Symmetric, - bool HasBias, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, - benchmark::State& state) { +template +void RunSQNBitGemmBenchmark(size_t BlkLen, + size_t M, size_t N, size_t K, + size_t Threads, + bool Symmetric, + bool HasBias, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + benchmark::State& state) { if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { - state.SkipWithMessage("QNBitGemm is not available with the given configuration on the current machine."); + state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine."); return; } @@ -77,7 +77,7 @@ void RunQNBitGemmBenchmark(size_t BlkLen, tp.get()); } - MLAS_QNBIT_GEMM_DATA_PARAMS params{}; + MLAS_QNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; if (PackedQuantBData != nullptr) @@ -121,16 +121,14 @@ static void QNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"}); b->ArgsProduct({ - {128}, // BlkLen - {1, 4096}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads - {int64_t{false}, int64_t{true}}, // Symmetric - {int64_t{false}, int64_t{true}}, // HasBias - std::is_same_v - ? std::vector{int64_t{HQNBIT_CompFp16}} - : std::vector{int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType + {128}, // BlkLen + {1}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {1, 8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{false}, int64_t{true}}, // HasBias + {int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType }); } @@ -142,19 +140,19 @@ template void QNBITGEMM_ENV(benchmark::State& state) { using onnxruntime::ParseEnvironmentVariableWithDefault; - const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_BLKLEN", 32); - const auto M = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_M", 1); - const auto N = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_N", 4096); - const auto K = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_K", 4096); - const auto Threads = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_THREADS", 1); - const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_SYMMETRIC", true); - const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_HAS_BIAS", false); - const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_COMPUTE_TYPE", + const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_BLKLEN", 32); + const auto M = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_M", 1); + const auto N = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_N", 4096); + const auto K = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_K", 4096); + const auto Threads = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_THREADS", 1); + const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_SYMMETRIC", true); + const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_HAS_BIAS", false); + const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_COMPUTE_TYPE", static_cast(SQNBIT_CompFp32)); - RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, - static_cast(ComputeType), - state); + RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, + static_cast(ComputeType), + state); std::ostringstream s; s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen