kompute: add mul_mat_q4_k shader #10097

Merged · 1 commit · Oct 31, 2024
2 changes: 2 additions & 0 deletions ggml/src/CMakeLists.txt
@@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
kompute-shaders/op_mul_mat_q8_0.comp
kompute-shaders/op_mul_mat_q4_0.comp
kompute-shaders/op_mul_mat_q4_1.comp
kompute-shaders/op_mul_mat_q4_k.comp
kompute-shaders/op_mul_mat_q6_k.comp
kompute-shaders/op_getrows_f32.comp
kompute-shaders/op_getrows_f16.comp
@@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
shaderop_mul_mat_q8_0.h
shaderop_mul_mat_q4_0.h
shaderop_mul_mat_q4_1.h
shaderop_mul_mat_q4_k.h
shaderop_mul_mat_q6_k.h
shaderop_getrows_f32.h
shaderop_getrows_f16.h
42 changes: 42 additions & 0 deletions ggml/src/ggml-kompute.cpp
@@ -20,6 +20,7 @@
#include "shaderop_mul_mat_q8_0.h"
#include "shaderop_mul_mat_q4_0.h"
#include "shaderop_mul_mat_q4_1.h"
#include "shaderop_mul_mat_q4_k.h"
#include "shaderop_mul_mat_q6_k.h"
#include "shaderop_mul_mat_mat_f32.h"
#include "shaderop_getrows_f32.h"
@@ -1067,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
}

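// Record a dispatch of the Q4_K mat-mul shader. Each 4x8 workgroup produces N_DST (4)
// destination rows, hence the (ne01 + 3)/4 workgroup count along X.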
static void ggml_vk_mul_mat_q4_k(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
const std::shared_ptr<kp::Tensor>& inB,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
int32_t ne1, int32_t r2, int32_t r3
) {
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);

struct PushConstants {
uint32_t inAOff, inBOff, outOff;
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
} pushConsts {
0, 0, 0,
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
};

std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(__func__)) {
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
} else {
s_algo = komputeManager()->getAlgorithm(__func__);
s_algo->setTensors({inA, inB, out});
s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
s_algo->setPushConstants<PushConstants>({pushConsts});
s_algo->updateDescriptors(s_kompute_context->pool.get());
}
seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_mul_mat_q6_k(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
@@ -1384,6 +1419,7 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_K:
return true;
default:
;
@@ -1635,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
);
break;
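// For Q4_K the broadcast ratios ne12/ne02 and ne13/ne03 are passed directly as r2/r3.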
case GGML_TYPE_Q4_K:
ggml_vk_mul_mat_q4_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
);
break;
case GGML_TYPE_Q6_K:
ggml_vk_mul_mat_q6_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
9 changes: 9 additions & 0 deletions ggml/src/kompute-shaders/common.comp
@@ -15,6 +15,7 @@
#define TWOPI_F 6.283185307179586f

#define QK_K 256
#define K_SCALE_SIZE 12

#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
@@ -64,6 +65,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
return reg;
}

#define sizeof_block_q4_k 144
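// Q4_K super-block: fp16 d, fp16 dmin, 12 bytes of packed 6-bit scales/mins and QK_K/2 = 128
// bytes of 4-bit quants, i.e. 2 + 2 + 12 + 128 = 144 bytes for 256 weights.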
struct block_q4_k {
float16_t d;
float16_t dmin;
uint8_t scales[K_SCALE_SIZE];
uint8_t qs[QK_K/2];
};

#define sizeof_block_q6_k 210
struct block_q6_k {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
133 changes: 133 additions & 0 deletions ggml/src/kompute-shaders/op_mul_mat_q4_k.comp
@@ -0,0 +1,133 @@
#version 450

#include "common.comp"

#define N_DST 4
#define SIZE_OF_BLOCK sizeof_block_q4_k

layout(local_size_x = 4) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;
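// 4 x 8 = 32 invocations per workgroup; the subgroup reduction at the end assumes the
// subgroup spans all 32 of them.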

layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne10;
int ne0;
int ne1;
int ne01;
int ne02;
int ne12;
int r2;
int r3;
} pcs;

void main() {
const uint16_t kmask1 = uint16_t(0x3f3f);
const uint16_t kmask2 = uint16_t(0x0f0f);
const uint16_t kmask3 = uint16_t(0xc0c0);

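// Split the invocation ID: ix selects which super-block this lane starts on (stride 4 below),
// iq/ir select an 8-value slice within the 256-value super-block.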
const uint ix = gl_SubgroupInvocationID/8; // 0...3
const uint it = gl_SubgroupInvocationID%8; // 0...7
const uint iq = it/4; // 0 or 1
const uint ir = it%4; // 0...3

const uint nb = pcs.ne00/QK_K;

const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;

const uint first_row = r0 * N_DST;
const uint ib_row = first_row * nb;

const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;

const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);

const uint xblk = ib_row + offset0 + pcs.inAOff;
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;

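// yl/yh cache 8+8 y values from the low and the high 128-element half of the current block.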
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
float all_sum = 0.f;

uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;

for (uint ib = ix; ib < nb; ib += 4) {
const uint blk_idx = ib + xblk;

float sumy[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
}

for (int row = 0; row < N_DST; row++) {
uint row_idx = row * nb;

uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);

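// Unpack the packed 6-bit scales/mins for this chunk: sc16[0]/sc16[2] carry scale pairs,
// sc16[1]/sc16[3] the matching min pairs.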
uint16_t sc16[4];
sc16[0] = sc_0 & kmask1;
sc16[1] = sc_2 & kmask1;
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);

float acc1[4] = {0.f, 0.f, 0.f, 0.f};
float acc2[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
acc1[0] += yl[i+0] * (q1 & 0x000F);
acc1[1] += yl[i+1] * (q1 & 0x0F00);
acc1[2] += yl[i+8] * (q1 & 0x00F0);
acc1[3] += yl[i+9] * (q1 & 0xF000);
acc2[0] += yh[i+0] * (q2 & 0x000F);
acc2[1] += yh[i+1] * (q2 & 0x0F00);
acc2[2] += yh[i+8] * (q2 & 0x00F0);
acc2[3] += yh[i+9] * (q2 & 0xF000);
}

uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );

float dall = float(inA[blk_idx + row_idx].d);
float dmin = float(inA[blk_idx + row_idx].dmin);
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
}

y4 += 4 * QK_K;
}

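// Reduce the per-lane partial sums across the subgroup; one invocation writes each output row.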
for (int row = 0; row < N_DST; ++row) {
all_sum = subgroupAdd(sumf[row]);
if (subgroupElect()) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
}
}
}