Skip to content

Commit

Permalink
fix linting
Browse files (browse the repository at this point in the history)
  • Loading branch information
fajin-corp committed Oct 30, 2024
1 parent 3780cf5 commit 3e095fc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 32 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ void SQNBITGEMM_ENV(benchmark::State& state) {
static_cast<int32_t>(CompFp32));

RunSQNBitGemmBenchmark<AType, BlkBitWidth>(BlkLen, M, N, K, Threads, Symmetric, HasBias,
static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(ComputeType),
state);
static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(ComputeType),
state);

std::ostringstream s;
s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen
Expand Down
49 changes: 19 additions & 30 deletions onnxruntime/test/mlas/unittest/test_sqnbitgemm_neon_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ Module Name:

class MlasNeonFp16CastTest : public MlasTestBase {
private:

template <size_t count>
void TestFp16ToFp32() {
std::vector<unsigned short> src(count);
Expand Down Expand Up @@ -80,7 +79,7 @@ class MlasNeonFp16CastTest : public MlasTestBase {
class MlasNeonFp16PrepackTest : public MlasTestBase {
private:
std::random_device _rd; // a seed source for the random number engine
std::mt19937 _gen; // mersenne_twister_engine seeded with rd()
std::mt19937 _gen; // mersenne_twister_engine seeded with rd()
std::uniform_int_distribution<> _distrib;

MLAS_FORCEINLINE
Expand All @@ -91,9 +90,8 @@ class MlasNeonFp16PrepackTest : public MlasTestBase {
}
}

template<size_t Ldb>
MLAS_FORCEINLINE
void Transpose8x8(std::vector<uint8_t>& src, size_t n, size_t k, std::vector<uint8_t>& dst) {
template <size_t Ldb>
MLAS_FORCEINLINE void Transpose8x8(std::vector<uint8_t>& src, size_t n, size_t k, std::vector<uint8_t>& dst) {
for (size_t c = 0; c < 8; c++) {
for (size_t r = 0; r < 8; r++) {
size_t i = (n + c) * Ldb + r + k;
Expand All @@ -118,8 +116,7 @@ class MlasNeonFp16PrepackTest : public MlasTestBase {
}

template <size_t Ldb, size_t N, size_t K>
MLAS_FORCEINLINE
void Prepack(std::vector<uint8_t>& src, std::vector<uint8_t>& dst) {
MLAS_FORCEINLINE void Prepack(std::vector<uint8_t>& src, std::vector<uint8_t>& dst) {
size_t n = 0;
for (; n + 8 <= N; n += 8) {
for (size_t k = 0; k < Ldb; k += 8) {
Expand All @@ -134,9 +131,8 @@ class MlasNeonFp16PrepackTest : public MlasTestBase {
}
}

template<size_t Ldb, size_t N, size_t K>
MLAS_FORCEINLINE
void Check(std::vector<uint8_t>& packed, std::vector<uint8_t>& ref) {
template <size_t Ldb, size_t N, size_t K>
MLAS_FORCEINLINE void Check(std::vector<uint8_t>& packed, std::vector<uint8_t>& ref) {
size_t n = 0;
for (; n + 8 <= N; n += 8) {
for (size_t i = 0; i < K; i += 2) {
Expand Down Expand Up @@ -164,15 +160,14 @@ class MlasNeonFp16PrepackTest : public MlasTestBase {
std::vector<uint8_t> input(BufferSize), packed(BufferSize), ref(BufferSize);
InitializeBuffer(input);
MlasSQNBitGemmPackQuantBData(
N, K, Bits, BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE::CompFp16, input.data(), packed.data(), nullptr
);
N, K, Bits, BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE::CompFp16, input.data(), packed.data(), nullptr);
Prepack<Ldb, N, K>(input, ref);
Check<Ldb, N, K>(packed, ref);
}

public:
MlasNeonFp16PrepackTest()
: _gen(_rd()), _distrib(0, 255) {
: _gen(_rd()), _distrib(0, 255) {
}

static const char* GetTestSuiteName() {
Expand All @@ -197,7 +192,7 @@ class MlasNeonFp16PrepackTest : public MlasTestBase {
class MlasNeonFp16DequantBTest : public MlasTestBase {
private:
std::random_device _rd; // a seed source for the random number engine
std::mt19937 _gen; // mersenne_twister_engine seeded with rd()
std::mt19937 _gen; // mersenne_twister_engine seeded with rd()
std::uniform_int_distribution<> _distrib;
std::uniform_real_distribution<float> _distribFp;

Expand Down Expand Up @@ -271,9 +266,8 @@ class MlasNeonFp16DequantBTest : public MlasTestBase {
return std::abs(f0 - f1) <= f1 * rtol + atol;
}

template<size_t Ldb, size_t N, size_t K>
MLAS_FORCEINLINE
void Check(std::vector<MLAS_FP16>& target, std::vector<MLAS_FP16>& ref) {
template <size_t Ldb, size_t N, size_t K>
MLAS_FORCEINLINE void Check(std::vector<MLAS_FP16>& target, std::vector<MLAS_FP16>& ref) {
size_t n = 0;
for (; n + 8 <= N; n += 8) {
for (size_t i = 0; i < K; ++i) {
Expand Down Expand Up @@ -311,15 +305,14 @@ class MlasNeonFp16DequantBTest : public MlasTestBase {
GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp16(
BlkLen, dequant.data(), reinterpret_cast<std::byte*>(input.data()), scales.data(),
UseZeroPoints ? reinterpret_cast<std::byte*>(zero_points.data()) : nullptr,
N, K, BlkNum
);
N, K, BlkNum);
DequantB<N, K, BlkLen, UseZeroPoints>(input, ref, scales, zero_points);
Check<BlkLen * BlkNum, N, K>(dequant, ref);
}

public:
MlasNeonFp16DequantBTest()
: _gen(_rd()), _distrib(0, 255), _distribFp(0.5f, 2.0f) {
: _gen(_rd()), _distrib(0, 255), _distribFp(0.5f, 2.0f) {
}

static const char* GetTestSuiteName() {
Expand Down Expand Up @@ -355,7 +348,7 @@ class MlasNeonFp16DequantBTest : public MlasTestBase {
class MlasNeonFp16SQ4BitGemmKernelTest : public MlasTestBase {
private:
std::random_device _rd; // a seed source for the random number engine
std::mt19937 _gen; // mersenne_twister_engine seeded with rd()
std::mt19937 _gen; // mersenne_twister_engine seeded with rd()

MLAS_FORCEINLINE
void InitializeBuffer(std::vector<MLAS_FP16>& buffer, float min, float max) {
Expand All @@ -366,7 +359,6 @@ class MlasNeonFp16SQ4BitGemmKernelTest : public MlasTestBase {
}
}


MLAS_FORCEINLINE
bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) {
float f0 = v0.ToFloat(), f1 = v1.ToFloat();
Expand Down Expand Up @@ -401,9 +393,8 @@ class MlasNeonFp16SQ4BitGemmKernelTest : public MlasTestBase {
}
}

template<size_t Ldc, size_t M, size_t N>
MLAS_FORCEINLINE
void Check(std::vector<MLAS_FP16>& target, std::vector<MLAS_FP16>& ref) {
template <size_t Ldc, size_t M, size_t N>
MLAS_FORCEINLINE void Check(std::vector<MLAS_FP16>& target, std::vector<MLAS_FP16>& ref) {
for (size_t m = 0; m < M; ++m) {
for (size_t n = 0; n < N; ++n) {
size_t i = m * Ldc + n;
Expand All @@ -427,12 +418,10 @@ class MlasNeonFp16SQ4BitGemmKernelTest : public MlasTestBase {

if constexpr (M == 2 && N == 8) {
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompFp16_8N_2M(
A.data(), B.data(), UseBias ? bias.data() : nullptr, C.data(), K, K, ldb, N
);
A.data(), B.data(), UseBias ? bias.data() : nullptr, C.data(), K, K, ldb, N);
} else {
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompFp16_Remainder(
A.data(), B.data(), UseBias ? bias.data() : nullptr, C.data(), M, N, K, K, ldb, N
);
A.data(), B.data(), UseBias ? bias.data() : nullptr, C.data(), M, N, K, K, ldb, N);
}

MatMul<M, N, K, ldb, UseBias>(A, B, bias, ref);
Expand All @@ -441,7 +430,7 @@ class MlasNeonFp16SQ4BitGemmKernelTest : public MlasTestBase {

public:
MlasNeonFp16SQ4BitGemmKernelTest()
: _gen(_rd()) {
: _gen(_rd()) {
}

static const char* GetTestSuiteName() {
Expand Down

0 comments on commit 3e095fc

Please sign in to comment.