diff --git a/ggml-metal.m b/ggml-metal.m index 8ca51c9262e3db..c7cf66c7f30f3f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -712,6 +712,7 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne00 == ne10); // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere + uint gqa = ne12/ne02; GGML_ASSERT(ne03 == ne13); // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs @@ -743,6 +744,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } @@ -845,6 +847,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { diff --git a/ggml-metal.metal b/ggml-metal.metal index e5f3623aeb1df0..e565b5efe35b2d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -343,14 +343,15 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre // N_DST, so this is another explicit assumption of the implementation. template void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, - int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * nsg + sgitg) * nr; - device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb + im/(ne12/ne02)*(ne02/QK4_0); + const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0); + device const block_q_type * x = (device const block_q_type *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne12; float yl[16]; // src1 vector cache float sumf[nr]={0.f}; @@ -397,10 +398,11 @@ kernel void kernel_mul_mat_q4_0_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_q4_1_f32( @@ -413,10 +415,11 @@ kernel void kernel_mul_mat_q4_1_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( @@ -797,6 +800,7 @@ kernel void kernel_mul_mat_q2_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -808,7 +812,8 @@ kernel void kernel_mul_mat_q2_K_f32( const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K); + const uint offset0 = r2/gqa*(ne02/QK_K); + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -938,6 +943,7 @@ kernel void kernel_mul_mat_q3_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -949,8 +955,8 @@ kernel void kernel_mul_mat_q3_K_f32( const int64_t r2 = tgpig.x; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K); + const uint offset0 = r2/gqa*(ne02/QK_K); + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02; float yl[16]; @@ -1054,6 +1060,7 @@ kernel void kernel_mul_mat_q3_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1065,8 +1072,8 @@ kernel void kernel_mul_mat_q3_K_f32( const int64_t r2 = tgpig.x; const int row = 2 * r0 + sgitg; - - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + r2/(ne12/ne02)*(ne02/QK_K); + const uint offset0 = r2/gqa*(ne02/QK_K); + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offest0; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02; const int ix = tiisg/4; const int il = 4 * (tiisg%4);// 0, 4, 8, 12 @@ -1123,6 +1130,7 @@ kernel void kernel_mul_mat_q4_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1142,7 +1150,8 @@ kernel void kernel_mul_mat_q4_K_f32( const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K); + const uint offset0 = r2/gqa*(ne02/QK_K); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; float yl[16]; float yh[16]; @@ -1225,6 +1234,7 @@ kernel void kernel_mul_mat_q4_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1238,7 +1248,8 @@ kernel void kernel_mul_mat_q4_K_f32( const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K); + const uint offset0 = r2/gqa*(ne02/QK_K); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; float yl[8]; float yh[8]; @@ -1311,6 +1322,7 @@ kernel void kernel_mul_mat_q5_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1322,8 +1334,8 @@ kernel void kernel_mul_mat_q5_K_f32( const int r2 = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K); + const uint offset0 = r2/gqa*(ne02/QK_K); + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12; float sumf[2]={0.f}; @@ -1474,6 +1486,7 @@ kernel void kernel_mul_mat_q6_K_f32( constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], constant int64_t & ne0[[buffer(15)]], + constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1490,8 +1503,8 @@ kernel void kernel_mul_mat_q6_K_f32( const int r2 = tgpig.z; const int row = 2 * r0 + sgitg; - - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + r2/(ne12/ne02)*(ne02/QK_K); + const uint offset0 = r2/gqa*(ne02/QK_K); + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12; float sumf = 0; @@ -1792,6 +1805,7 @@ kernel void kernel_mul_mm(device const uchar * src0, constant int64_t & ne12, constant int64_t & ne0, constant int64_t & ne1, + constant uint & gqa, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -1818,7 +1832,8 @@ kernel void kernel_mul_mm(device const uchar * src0, } short il = (tiitg % THREAD_PER_ROW); - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + im/(ne12/ne02)*nb02) + il/nl; + uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12; @@ -1909,7 +1924,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ - constant int64_t &, constant int64_t &, threadgroup uchar *, uint3, uint, uint); + constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;