forked from ggerganov/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
kompute: add mul_mat_q4_k shader (ggerganov#10097)
This is a more or less direct translation from the Metal implementation to GLSL. Signed-off-by: Sergio Lopez <slp@redhat.com>
- Loading branch information
Showing
4 changed files
with
186 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
#version 450 | ||
|
||
#include "common.comp" | ||
|
||
#define N_DST 4 | ||
#define SIZE_OF_BLOCK sizeof_block_q4_k | ||
|
||
layout(local_size_x = 4) in; | ||
layout(local_size_y = 8) in; | ||
layout(local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; }; | ||
layout (binding = 1) readonly buffer tensorInB { float inB[]; }; | ||
layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; | ||
|
||
layout (push_constant) uniform parameter { | ||
uint inAOff; | ||
uint inBOff; | ||
uint outOff; | ||
int ne00; | ||
int ne10; | ||
int ne0; | ||
int ne1; | ||
int ne01; | ||
int ne02; | ||
int ne12; | ||
int r2; | ||
int r3; | ||
} pcs; | ||
|
||
void main() { | ||
const uint16_t kmask1 = uint16_t(0x3f3f); | ||
const uint16_t kmask2 = uint16_t(0x0f0f); | ||
const uint16_t kmask3 = uint16_t(0xc0c0); | ||
|
||
const uint ix = gl_SubgroupInvocationID/8; // 0...3 | ||
const uint it = gl_SubgroupInvocationID%8; // 0...7 | ||
const uint iq = it/4; // 0 or 1 | ||
const uint ir = it%4; // 0...3 | ||
|
||
const uint nb = pcs.ne00/QK_K; | ||
|
||
const uint r0 = gl_WorkGroupID.x; | ||
const uint r1 = gl_WorkGroupID.y; | ||
const uint im = gl_WorkGroupID.z; | ||
|
||
const uint first_row = r0 * N_DST; | ||
const uint ib_row = first_row * nb; | ||
|
||
const uint i12 = im%pcs.ne12; | ||
const uint i13 = im/pcs.ne12; | ||
|
||
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02); | ||
|
||
const uint xblk = ib_row + offset0 + pcs.inAOff; | ||
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; | ||
|
||
float yl[16]; | ||
float yh[16]; | ||
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f}; | ||
float all_sum = 0.f; | ||
|
||
uint y4 = y + ix * QK_K + 64 * iq + 8 * ir; | ||
|
||
for (uint ib = ix; ib < nb; ib += 4) { | ||
const uint blk_idx = ib + xblk; | ||
|
||
float sumy[4] = {0.f, 0.f, 0.f, 0.f}; | ||
for (int i = 0; i < 8; ++i) { | ||
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0]; | ||
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8]; | ||
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0]; | ||
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8]; | ||
} | ||
|
||
for (int row = 0; row < N_DST; row++) { | ||
uint row_idx = row * nb; | ||
|
||
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0); | ||
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2); | ||
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4); | ||
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6); | ||
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8); | ||
|
||
uint16_t sc16[4]; | ||
sc16[0] = sc_0 & kmask1; | ||
sc16[1] = sc_2 & kmask1; | ||
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2); | ||
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2); | ||
|
||
float acc1[4] = {0.f, 0.f, 0.f, 0.f}; | ||
float acc2[4] = {0.f, 0.f, 0.f, 0.f}; | ||
for (int i = 0; i < 8; i += 2) { | ||
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i); | ||
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i); | ||
acc1[0] += yl[i+0] * (q1 & 0x000F); | ||
acc1[1] += yl[i+1] * (q1 & 0x0F00); | ||
acc1[2] += yl[i+8] * (q1 & 0x00F0); | ||
acc1[3] += yl[i+9] * (q1 & 0xF000); | ||
acc2[0] += yh[i+0] * (q2 & 0x000F); | ||
acc2[1] += yh[i+1] * (q2 & 0x0F00); | ||
acc2[2] += yh[i+8] * (q2 & 0x00F0); | ||
acc2[3] += yh[i+9] * (q2 & 0xF000); | ||
} | ||
|
||
uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF); | ||
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 ); | ||
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF); | ||
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 ); | ||
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF); | ||
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 ); | ||
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF); | ||
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 ); | ||
|
||
float dall = float(inA[blk_idx + row_idx].d); | ||
float dmin = float(inA[blk_idx + row_idx].dmin); | ||
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 + | ||
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f + | ||
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 + | ||
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) - | ||
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7); | ||
} | ||
|
||
y4 += 4 * QK_K; | ||
} | ||
|
||
for (int row = 0; row < N_DST; ++row) { | ||
all_sum = subgroupAdd(sumf[row]); | ||
if (subgroupElect()) { | ||
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum; | ||
} | ||
} | ||
} |