Skip to content

Commit

Permalink
hacky KHR cooperative matrix prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffbolznv committed Nov 11, 2024
1 parent 60e17ce commit 3416010
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions ggml/src/vulkan-shaders/mul_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_EXT_shader_explicit_arithmetic_types : require

#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
Expand Down Expand Up @@ -152,12 +155,10 @@ void main() {
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif

float sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM];
FLOAT_TYPE cache_b[WNITER * TN];
coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> sums[WM * WN / 16 / 16];

[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = 0.0f;
[[unroll]] for (uint i = 0; i < WM * WN / 16 / 16; i++) {
sums[i] = coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);
}

[[unroll]] for (uint block = start_k; block < end_k; block += BK) {
Expand Down Expand Up @@ -446,27 +447,14 @@ void main() {
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;

for (uint i = 0; i < BK; i++) {
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint j = 0; j < TM; j++) {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) {
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
}
}

[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
}
}
[[unroll]] for (uint i = 0; i < WM; i += 16) {
[[unroll]] for (uint j = 0; j < WN; j += 16) {
[[unroll]] for (uint k = 0; k < BK; k += 16) {
coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> matB;
coopMatLoad(matA, buf_a, (warp_r * WM + i) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, buf_b, (warp_c * WN + j) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutColumnMajor);
sums[(i / 16) * (WN / 16) + (j / 16)] = coopMatMulAdd(matA, matB, sums[(i / 16) * (WN / 16) + (j / 16)]);
}
}
}
Expand All @@ -481,6 +469,19 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

#if 1
#ifndef MUL_MAT_ID
// XXX TODO this is missing bounds checking against p.M and p.N,
// which probably requires spilling to shared memory and doing scalar stores.
// But sums[] may not all fit in shared memory...
[[unroll]] for (uint i = 0; i < WM; i += 16) {
[[unroll]] for (uint j = 0; j < WN; j += 16) {
coopMatStore(sums[(i / 16) * (WN / 16) + (j / 16)], data_d, offsets + (dc + j) * p.stride_d + dr + i, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
}
}
#endif
#else

[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

Expand All @@ -505,4 +506,5 @@ void main() {
}
}
}
#endif
}

0 comments on commit 3416010

Please sign in to comment.