[ARM] MatMulNBits Fp16 support - API change only (microsoft#22826)
### Description
A break-down PR of microsoft#22651.
Op API change only:
- add templates to the functions and classes that support both fp32 and fp16
- rename functions, classes and files that support fp32 and fp16 from SQNBxxx to QNBxxx
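
A minimal sketch of the templating pattern this describes, with an assumed, abbreviated field set (only `A` and `lda` are visible in the diffs below; the rest is illustrative):

```cpp
// Sketch only; not the verbatim MLAS declarations.
#include <cstddef>

class MLAS_FP16;  // MLAS half-precision type, defined elsewhere in MLAS

template <typename AType>  // AType = float for fp32, MLAS_FP16 for fp16
struct MLAS_QNBIT_GEMM_DATA_PARAMS {
    const AType* A = nullptr;  // activation matrix
    std::size_t lda = 0;       // leading dimension of A
    // ... quantized B data, scales, zero points, C, Bias, etc.
};
```

The type-neutral QNBxxx names then cover both element types, while S (single) and H (half) prefixes remain on kernels that are specific to one of them.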


### Motivation and Context
fajin-corp authored and ankitm3k committed Dec 11, 2024
1 parent 345a628 commit aecfb1f
Showing 12 changed files with 58 additions and 97 deletions.
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -33,7 +33,6 @@ constexpr size_t A = 0,
};

typedef enum {
-Level0, /*!< input fp32, accumulator fp32 */
Level1, /*!< input fp32, accumulator fp32 */
Level2, /*!< input fp16, accumulator fp16 */
Level3, /*!< input bf16, accumulator fp32 */
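For reference, the accuracy-level enum after this one-line deletion reads roughly as follows (Level4 and the ACCURACY_LEVEL tag do not appear in the hunk and are assumptions based on the visible pattern):

```cpp
typedef enum {
  Level1, /*!< input fp32, accumulator fp32 */
  Level2, /*!< input fp16, accumulator fp16 */
  Level3, /*!< input bf16, accumulator fp32 */
  Level4, /*!< input int8, accumulator int32 */  // assumed continuation
} ACCURACY_LEVEL;                                // tag name assumed
```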
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp
@@ -23,7 +23,7 @@ Module Name:
#include <type_traits>

#include "fp16_common.h"
#include "sqnbitgemm.h"
#include "qnbitgemm.h"
#include "sqnbitgemm_kernel_neon.h"

namespace sqnbitgemm_neon
@@ -131,7 +131,7 @@ HQ4BitGemmPackQuantBData_CompFp16(
size_t N,
size_t K,
size_t BlkLen,
-MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
+MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
const std::byte* QuantBDataBegin,
std::byte* PackedQuantBDataBegin,
MLAS_THREADPOOL* ThreadPool
7 changes: 0 additions & 7 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -1049,13 +1049,6 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;

extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;

-//
-// Rotary embedding dispatch structure.
-//
-struct MLAS_ROPE_DISPATCH;
-extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
-
-
//
// Quantized depthwise convolution kernels.
//
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -573,6 +573,9 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;
+
+// MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions
+this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
}

#if defined(__linux__)
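The comment in this hunk implies the assignment sits inside the dot-product-capable branch of the platform constructor, alongside the other Dot kernels. A sketch of that registration pattern, with the guard's predicate name assumed:

```cpp
// Hypothetical guard name; the assignments are the ones shown in the hunk.
if (HasDotProductInstructions) {
    this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot;
    this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot;
    this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;

    // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions
    this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
}
```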
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
@@ -402,7 +402,7 @@ SQ4BitGemm_CompFp32(
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

-GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32(
+GetMlasPlatform().QNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
BlkLen,
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
);
@@ -808,7 +808,7 @@ GetQNBitGemm(QNBitGemmVariant variant)
{
switch (variant) {
case HQNBitGemmVariant_BitWidth4_CompFp16:
-return HQ4BitGemm_CompFp16;
+return nullptr;
default:
return nullptr;
}
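With this change GetQNBitGemm returns nullptr for the fp16 variant, consistent with an API-only PR: the variant is declared but no kernel is wired up yet. Callers therefore have to treat a null pointer as "variant unavailable", roughly:

```cpp
// Illustrative caller-side pattern; the surrounding driver code is assumed.
auto* const gemm_fn = GetQNBitGemm(variant);
if (gemm_fn == nullptr) {
    // Variant not implemented in this build; fall back (e.g. to the fp32
    // path) or report the configuration as unsupported.
}
```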
10 changes: 5 additions & 5 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
@@ -91,17 +91,17 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
//

/** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */
-typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)(
+typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)(
size_t N,
size_t K,
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

-Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;
+SQ4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;

/** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */
-typedef void(Q4BitGemmPackQuantBData_Fn)(
+typedef void(SQ4BitGemmPackQuantBData_Fn)(
size_t N,
size_t K,
size_t BlkLen,
@@ -151,7 +151,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

-Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;
+SQ4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;

/**
* @brief Gets the required byte alignment of the per-GEMM intermediate workspace.
@@ -164,7 +164,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

-Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;
+SQ4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;

//
// SQNBIT_CompFp32 kernel function prototypes.
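MLAS_QNBIT_GEMM_DISPATCH is a plain function-pointer dispatch table: every entry defaults to nullptr and each platform kernel set fills in only what it implements. A sketch of how one entry might be populated (the stub body and the Example names are illustrative, not the NEON implementation):

```cpp
namespace {
// Stub matching the Q4BitGemmPackQuantBDataSize_Fn signature above.
size_t ExampleQ4BitGemmPackQuantBDataSize(size_t N, size_t K, size_t BlkLen,
                                          MLAS_QNBIT_GEMM_COMPUTE_TYPE /*ComputeType*/) {
    const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;  // quant blocks per column
    return N * BlockCountK * (BlkLen / 2);                 // 4-bit: two weights per byte
}
}  // namespace

const MLAS_QNBIT_GEMM_DISPATCH MlasQNBitGemmDispatchExample = []() {
    MLAS_QNBIT_GEMM_DISPATCH d;
    d.Q4BitGemmPackQuantBDataSize = ExampleQ4BitGemmPackQuantBDataSize;
    return d;  // entries left nullptr mean "not supported on this platform"
}();
```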
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
@@ -20,7 +20,7 @@ Module Name:
#include <cassert>

#include "qnbitgemm.h"
#include "qnbitgemm_kernel_neon.h"
#include "sqnbitgemm_kernel_neon.h"
#include "sqnbitgemm_q8_block.h"

namespace sqnbitgemm_neon
41 changes: 0 additions & 41 deletions onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h
@@ -64,47 +64,6 @@ SQ4BitBlkDequantBForSgemm_CompFp32(
size_t BlockCountK
);

-// HQNBIT_CompFp16 declarations
-#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
-void
-HQ4BitGemmPackQuantBData_CompFp16(
-size_t N,
-size_t K,
-size_t BlkLen,
-MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
-const std::byte* QuantBDataBegin,
-std::byte* PackedQuantBDataBegin,
-MLAS_THREADPOOL* ThreadPool
-);
-
-void
-HQ4BitBlkDequantBForHgemm_CompFp16(
-size_t BlkLen,
-MLAS_FP16* FpData,
-const std::byte* QuantBData,
-const MLAS_FP16* QuantBScale,
-const std::byte* QuantBZeroPoint,
-size_t CountN,
-size_t K,
-size_t BlockCountK
-);
-
-void
-HQ4BitGemmKernel_CompFp16(
-const MLAS_FP16* A,
-const MLAS_FP16* B,
-const MLAS_FP16* Bias,
-MLAS_FP16* C,
-size_t CountM,
-size_t CountN,
-size_t K,
-size_t lda,
-size_t ldb,
-size_t ldc
-);
-
-#endif // !(defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64))
-
// SQNBIT_CompInt8 declarations

void
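The three declarations deleted above form a pack / dequantize / compute pipeline. A driver would call them in roughly this order; the wrapper below is a sketch (the sequencing, buffer wiring, and row-major leading dimensions are assumptions, while the three call signatures are taken verbatim from the deleted block):

```cpp
void Hq4BitGemmFp16Sketch(const MLAS_FP16* A, const MLAS_FP16* Bias, MLAS_FP16* C,
                          MLAS_FP16* DequantB,  // fp16 scratch for dequantized B
                          const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin,
                          const MLAS_FP16* QuantBScale, const std::byte* QuantBZeroPoint,
                          size_t M, size_t N, size_t K, size_t BlkLen, size_t BlockCountK,
                          MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, MLAS_THREADPOOL* ThreadPool) {
    // 1. One-time repack of the 4-bit quantized B into the kernel's layout.
    HQ4BitGemmPackQuantBData_CompFp16(N, K, BlkLen, ComputeType,
                                      QuantBDataBegin, PackedQuantBDataBegin, ThreadPool);

    // 2. Dequantize the packed B blocks into the fp16 scratch buffer.
    HQ4BitBlkDequantBForHgemm_CompFp16(BlkLen, DequantB, PackedQuantBDataBegin,
                                       QuantBScale, QuantBZeroPoint, N, K, BlockCountK);

    // 3. Dense fp16 GEMM over the dequantized weights.
    HQ4BitGemmKernel_CompFp16(A, DequantB, Bias, C, M, N, K,
                              /*lda=*/K, /*ldb=*/N, /*ldc=*/N);
}
```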
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp
@@ -22,7 +22,7 @@ Module Name:
#include <cassert>

#include "qnbitgemm.h"
#include "qnbitgemm_kernel_neon.h"
#include "sqnbitgemm_kernel_neon.h"

namespace sqnbitgemm_neon
{
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
@@ -22,7 +22,7 @@ Module Name:
#include <cassert>

#include "qnbitgemm.h"
#include "qnbitgemm_kernel_neon.h"
#include "sqnbitgemm_kernel_neon.h"
#include "sqnbitgemm_q8_block.h"

namespace sqnbitgemm_neon
@@ -15,7 +15,7 @@ diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 28ae64c4d5..0c77e0ca78 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
-@@ -83,6 +83,9 @@ Abstract:
+@@ -82,6 +82,9 @@ Abstract:

#if (!defined(_MSC_VER)) || (_MSC_VER >= 1930)
#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC)
@@ -25,7 +25,7 @@ index 28ae64c4d5..0c77e0ca78 100644
#if !defined(__APPLE__)
// Had to temporary disable fp16 under APPLE ARM64, as compiling
// the source files require a hardware specific compilation flag.
-@@ -91,6 +94,7 @@ Abstract:
+@@ -90,6 +93,7 @@ Abstract:

#define MLAS_F16VEC_INTRINSICS_SUPPORTED

@@ -95,8 +95,17 @@ diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index b3c9461293..424c3b0441 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
-@@ -574,7 +574,7 @@ Return Value:
-this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;
+@@ -20,7 +20,7 @@ Abstract:
+#include <thread>
+#include <mutex>
+
+-#if defined(MLAS_TARGET_POWER)
++#if defined(MLAS_TARGET_POWER)
+#if defined(__linux__)
+#include <sys/auxv.h>
+#elif defined(_AIX)
+@@ -536,7 +536,7 @@ Return Value:
+this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
+}

-#if defined(__linux__)
@@ -126,7 +135,7 @@ diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc
index 2c6d23e4de..61aaacdfd6 100644
--- a/onnxruntime/core/providers/cpu/math/matmul.cc
+++ b/onnxruntime/core/providers/cpu/math/matmul.cc
-@@ -133,7 +133,7 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
+@@ -132,7 +132,7 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {

return Status::OK();
}
60 changes: 29 additions & 31 deletions onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp
@@ -17,16 +17,16 @@
#include "core/util/thread_utils.h"
#include "core/platform/env_var_utils.h"

-template <typename AType, size_t BlkBitWidth>
-void RunQNBitGemmBenchmark(size_t BlkLen,
-size_t M, size_t N, size_t K,
-size_t Threads,
-bool Symmetric,
-bool HasBias,
-MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
-benchmark::State& state) {
+template <size_t BlkBitWidth>
+void RunSQNBitGemmBenchmark(size_t BlkLen,
+size_t M, size_t N, size_t K,
+size_t Threads,
+bool Symmetric,
+bool HasBias,
+MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
+benchmark::State& state) {
if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) {
state.SkipWithMessage("QNBitGemm is not available with the given configuration on the current machine.");
state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine.");
return;
}

@@ -77,7 +77,7 @@ void RunQNBitGemmBenchmark(size_t BlkLen,
tp.get());
}

-MLAS_QNBIT_GEMM_DATA_PARAMS<AType> params{};
+MLAS_QNBIT_GEMM_DATA_PARAMS<float> params{};
params.A = A.data();
params.lda = K;
if (PackedQuantBData != nullptr)
@@ -121,16 +121,14 @@ static void QNBitGemmArgs(benchmark::internal::Benchmark* b) {
b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"});

b->ArgsProduct({
-{128}, // BlkLen
-{1, 4096}, // M
-{4096, 11008}, // N
-{4096, 11008}, // K
-{1, 8}, // Threads
-{int64_t{false}, int64_t{true}}, // Symmetric
-{int64_t{false}, int64_t{true}}, // HasBias
-std::is_same_v<AType, MLAS_FP16>
-? std::vector<int64_t>{int64_t{HQNBIT_CompFp16}}
-: std::vector<int64_t>{int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType
+{128}, // BlkLen
+{1}, // M
+{4096, 11008}, // N
+{4096, 11008}, // K
+{1, 8}, // Threads
+{int64_t{false}, int64_t{true}}, // Symmetric
+{int64_t{false}, int64_t{true}}, // HasBias
+{int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType
});
}

@@ -142,19 +140,19 @@ template <typename AType, size_t BlkBitWidth>
void QNBITGEMM_ENV(benchmark::State& state) {
using onnxruntime::ParseEnvironmentVariableWithDefault;

const auto BlkLen = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_BLKLEN", 32);
const auto M = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_M", 1);
const auto N = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_N", 4096);
const auto K = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_K", 4096);
const auto Threads = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_THREADS", 1);
const auto Symmetric = ParseEnvironmentVariableWithDefault<bool>("ORT_QNBITGEMM_SYMMETRIC", true);
const auto HasBias = ParseEnvironmentVariableWithDefault<bool>("ORT_QNBITGEMM_HAS_BIAS", false);
const auto ComputeType = ParseEnvironmentVariableWithDefault<int32_t>("ORT_QNBITGEMM_COMPUTE_TYPE",
const auto BlkLen = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_BLKLEN", 32);
const auto M = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_M", 1);
const auto N = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_N", 4096);
const auto K = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_K", 4096);
const auto Threads = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_THREADS", 1);
const auto Symmetric = ParseEnvironmentVariableWithDefault<bool>("ORT_SQNBITGEMM_SYMMETRIC", true);
const auto HasBias = ParseEnvironmentVariableWithDefault<bool>("ORT_SQNBITGEMM_HAS_BIAS", false);
const auto ComputeType = ParseEnvironmentVariableWithDefault<int32_t>("ORT_SQNBITGEMM_COMPUTE_TYPE",
static_cast<int32_t>(SQNBIT_CompFp32));

-RunQNBitGemmBenchmark<AType, BlkBitWidth>(BlkLen, M, N, K, Threads, Symmetric, HasBias,
-static_cast<MLAS_QNBIT_GEMM_COMPUTE_TYPE>(ComputeType),
-state);
+RunSQNBitGemmBenchmark<BlkBitWidth>(BlkLen, M, N, K, Threads, Symmetric, HasBias,
+static_cast<MLAS_QNBIT_GEMM_COMPUTE_TYPE>(ComputeType),
+state);

std::ostringstream s;
s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen
