Skip to content

Commit

Permalink
metal: fix performance degradation from gqa
Browse files Browse the repository at this point in the history
Integers are slow on the GPU, and 64-bit divides are extremely slow.
In the context of GQA, we introduce a 64-bit divide that cannot be
optimized out by the compiler, which results in a decrease of ~8% in
inference performance. This commit fixes that issue by calculating a
part of the offset with a 32-bit divide. Naturally, this limits the
size of a single matrix to ~4GB. However, this limitation should
suffice for the near future.
  • Loading branch information
lshzh-ww committed Aug 15, 2023
1 parent 5f6de2a commit 3d8d51c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
3 changes: 3 additions & 0 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ void ggml_metal_graph_compute(

GGML_ASSERT(ne00 == ne10);
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
uint gqa = ne12/ne02;
GGML_ASSERT(ne03 == ne13);

// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
Expand Down Expand Up @@ -743,6 +744,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
Expand Down Expand Up @@ -845,6 +847,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
Expand Down
49 changes: 32 additions & 17 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,15 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa,
uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr;
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb + im/(ne12/ne02)*(ne02/QK4_0);
const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0);
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne12;
float yl[16]; // src1 vector cache
float sumf[nr]={0.f};
Expand Down Expand Up @@ -397,10 +398,11 @@ kernel void kernel_mul_mat_q4_0_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg);
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_q4_1_f32(
Expand All @@ -413,10 +415,11 @@ kernel void kernel_mul_mat_q4_1_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg);
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_f16_f32(
Expand Down Expand Up @@ -797,6 +800,7 @@ kernel void kernel_mul_mat_q2_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
Expand All @@ -808,7 +812,8 @@ kernel void kernel_mul_mat_q2_K_f32(

const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
Expand Down Expand Up @@ -938,6 +943,7 @@ kernel void kernel_mul_mat_q3_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
Expand All @@ -949,8 +955,8 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r2 = tgpig.x;

const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;

device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K);
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;

float yl[16];
Expand Down Expand Up @@ -1054,6 +1060,7 @@ kernel void kernel_mul_mat_q3_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
Expand All @@ -1065,8 +1072,8 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r2 = tgpig.x;

const int row = 2 * r0 + sgitg;

device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + r2/(ne12/ne02)*(ne02/QK_K);
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offest0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
const int ix = tiisg/4;
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
Expand Down Expand Up @@ -1123,6 +1130,7 @@ kernel void kernel_mul_mat_q4_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
Expand All @@ -1142,7 +1150,8 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
float yl[16];
float yh[16];
Expand Down Expand Up @@ -1225,6 +1234,7 @@ kernel void kernel_mul_mat_q4_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
Expand All @@ -1238,7 +1248,8 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K);
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
float yl[8];
float yh[8];
Expand Down Expand Up @@ -1311,6 +1322,7 @@ kernel void kernel_mul_mat_q5_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
Expand All @@ -1322,8 +1334,8 @@ kernel void kernel_mul_mat_q5_K_f32(
const int r2 = tgpig.z;

const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;

device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K);
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;

float sumf[2]={0.f};
Expand Down Expand Up @@ -1474,6 +1486,7 @@ kernel void kernel_mul_mat_q6_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
Expand All @@ -1490,8 +1503,8 @@ kernel void kernel_mul_mat_q6_K_f32(
const int r2 = tgpig.z;

const int row = 2 * r0 + sgitg;

device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + r2/(ne12/ne02)*(ne02/QK_K);
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;

float sumf = 0;
Expand Down Expand Up @@ -1792,6 +1805,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
constant int64_t & ne12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
Expand All @@ -1818,7 +1832,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
}

short il = (tiitg % THREAD_PER_ROW);
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + im/(ne12/ne02)*nb02) + il/nl;
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12;

Expand Down Expand Up @@ -1909,7 +1924,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows

typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
constant int64_t &, constant int64_t &, threadgroup uchar *, uint3, uint, uint);
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);

template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
Expand Down

0 comments on commit 3d8d51c

Please sign in to comment.