Skip to content

Commit

Permalink
backend cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels (#9921)
Browse files Browse the repository at this point in the history
* backend-cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
  • Loading branch information
chaxu01 and slaren authored Nov 15, 2024
1 parent ae8de6d commit 1607a5e
Show file tree
Hide file tree
Showing 9 changed files with 271 additions and 20 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,10 @@ ggml/src/ggml-cuda/%.o: \
$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -mtgpu -c -o $@ $<
endif # GGML_MUSA

ifndef GGML_NO_CPU_AARCH64
MK_CPPFLAGS += -DGGML_USE_CPU_AARCH64
endif

ifdef GGML_METAL
MK_CPPFLAGS += -DGGML_USE_METAL
MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
Expand Down
1 change: 1 addition & 0 deletions ggml/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ else()
endif()

option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)

option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
Expand Down
3 changes: 3 additions & 0 deletions ggml/include/ggml-cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ extern "C" {
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
#endif

GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
GGML_BACKEND_API bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft);

#ifdef __cplusplus
}
#endif
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ else()
message(STATUS "Unknown architecture")
endif()

if (GGML_CPU_AARCH64)
message(STATUS "Using runtime weight conversion of Q4_0 to Q4_0_x_x to enable optimized GEMM/GEMV kernels")
add_compile_definitions(GGML_USE_CPU_AARCH64)
endif()

target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")

Expand Down
144 changes: 144 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu-aarch64.c
Original file line number Diff line number Diff line change
Expand Up @@ -3385,3 +3385,147 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
}

// FIXME: this code is duplicated from ggml-aarch64.c
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
block_q4_0x4 out;

for (int i = 0; i < 4; i++) {
out.d[i] = in[i].d;
}

for (int i = 0; i < QK4_0 * 2; i++) {
int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
src_offset += (i % blck_size_interleave);

out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
}

return out;
}

// interleave 8 block_q4_0s in blocks of blck_size_interleave
// returns an interleaved block_q4_0x8
// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
block_q4_0x8 out;

for (int i = 0; i < 8; i++) {
out.d[i] = in[i].d;
}

for (int i = 0; i < QK4_0 * 4; i++) {
int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
src_offset += (i % blck_size_interleave);

out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
}

return out;
}

static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);

block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
const block_q4_0 * src = (const block_q4_0 *)data;
block_q4_0 dst_tmp[4];
int nrow = t->ne[1]; // Number of rows
int nrows_interleaved = 4;
int nblocks = t->ne[0] / QK4_0;

GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));

if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1;
}

for (int b = 0; b < nrow; b += nrows_interleaved) {
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++) {
dst_tmp[i] = src[x + i * nblocks];
}
*dst++ = make_block_q4_0x4(dst_tmp, interleave_block, 0x88);
}
src += nrows_interleaved * nblocks;
}
return 0;

GGML_UNUSED(data_size);
}

static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 8);

block_q4_0x8 * dst = (block_q4_0x8*)t->data;
const block_q4_0 * src = (const block_q4_0*) data;
block_q4_0 dst_tmp[8];
int nrow = t->ne[1]; // Number of rows
int nrows_interleaved = 8;
int nblocks = t->ne[0] / QK4_0;

GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));

if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1;
}

for (int b = 0; b < nrow; b += nrows_interleaved) {
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++ ) {
dst_tmp[i] = src[x + i * nblocks];
}
*dst++ = make_block_q4_0x8(dst_tmp, interleave_block, 0x88);
}
src += nrows_interleaved * nblocks;
}
return 0;

GGML_UNUSED(data_size);
}

// Prepare for optimized kernels if applicable
void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
if (cur->type == repack_type) {
memcpy(cur->data, data, data_size);
return;
}

GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);

switch (repack_type) {
case GGML_TYPE_Q4_0_8_8:
repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
break;
case GGML_TYPE_Q4_0_4_8:
repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
break;
case GGML_TYPE_Q4_0_4_4:
repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
break;
default:
GGML_ABORT("Unsupported type");
}
}

enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
if (cur->type == GGML_TYPE_Q4_0) {
// TODO: enable for AVX2 - currently disabled due to bad gemv performance
if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
return GGML_TYPE_Q4_0_8_8;
}
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
return GGML_TYPE_Q4_0_4_8;
}
if (ggml_cpu_has_neon()) {
return GGML_TYPE_Q4_0_4_4;
}
}

return cur->type;
}
3 changes: 3 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);

#ifdef __cplusplus
}
#endif
Expand Down
23 changes: 13 additions & 10 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -7330,6 +7330,7 @@ static void ggml_compute_forward_group_norm(
static void ggml_compute_forward_mul_mat_one_chunk(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
const enum ggml_type type,
const int64_t num_rows_per_vec_dot,
const int64_t ir0_start,
const int64_t ir0_end,
Expand All @@ -7341,8 +7342,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(

GGML_TENSOR_BINARY_OP_LOCALS

const enum ggml_type type = src0->type;

const bool src1_cont = ggml_is_contiguous(src1);

ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
Expand Down Expand Up @@ -7430,7 +7429,11 @@ static void ggml_compute_forward_mul_mat(
const int ith = params->ith;
const int nth = params->nth;

const enum ggml_type type = src0->type;
enum ggml_type type = src0->type;

if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
type = (enum ggml_type)(intptr_t)src0->extra;
}

enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
Expand Down Expand Up @@ -7469,15 +7472,15 @@ static void ggml_compute_forward_mul_mat(
if (src1_cont) {
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
nb01/ggml_type_size(type),
(const char *)src1->data + i12*nb12 + i13*nb13,
nb11/ggml_type_size(src1->type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),
ith, nth,
src0->type,
type,
src1->type,
dst->type))
goto UseGgmlGemm1;
Expand Down Expand Up @@ -7530,15 +7533,15 @@ UseGgmlGemm1:;

for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
nb01/ggml_type_size(type),
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
row_size/ggml_type_size(vec_dot_type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),
ith, nth,
src0->type,
type,
vec_dot_type,
dst->type))
goto UseGgmlGemm2;
Expand Down Expand Up @@ -7623,7 +7626,7 @@ UseGgmlGemm2:;
const int64_t ir1_start = dr1 * ith1;
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);

ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);

if (nth >= nchunk0 * nchunk1) {
break;
Expand Down
Loading

0 comments on commit 1607a5e

Please sign in to comment.