diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/vulkan-shaders/mul_mm.comp
index fffdd18189d55..f520734d82cd1 100644
--- a/ggml/src/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/vulkan-shaders/mul_mm.comp
@@ -2,6 +2,9 @@
 
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
+#extension GL_KHR_cooperative_matrix : require
+#extension GL_KHR_memory_scope_semantics : require
+#extension GL_EXT_shader_explicit_arithmetic_types : require
 
 #ifdef FLOAT16
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
@@ -152,12 +155,10 @@ void main() {
     uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
 #endif
 
-    float sums[WMITER * TM * WNITER * TN];
-    FLOAT_TYPE cache_a[WMITER * TM];
-    FLOAT_TYPE cache_b[WNITER * TN];
+    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> sums[WM * WN / 16 / 16];
 
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
+    [[unroll]] for (uint i = 0; i < WM * WN / 16 / 16; i++) {
+        sums[i] = coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);
     }
 
     [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
@@ -446,27 +447,14 @@ void main() {
         pos_a += BK / LOAD_VEC_A;
         pos_b += BK / LOAD_VEC_B;
 
-        for (uint i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (uint j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                            sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
-                        }
-                    }
+        [[unroll]] for (uint i = 0; i < WM; i += 16) {
+            [[unroll]] for (uint j = 0; j < WN; j += 16) {
+                [[unroll]] for (uint k = 0; k < BK; k += 16) {
+                    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
+                    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> matB;
+                    coopMatLoad(matA, buf_a, (warp_r * WM + i) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutRowMajor);
+                    coopMatLoad(matB, buf_b, (warp_c * WN + j) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutColumnMajor);
+                    sums[(i / 16) * (WN / 16) + (j / 16)] = coopMatMulAdd(matA, matB, sums[(i / 16) * (WN / 16) + (j / 16)]);
                 }
             }
         }
@@ -481,6 +469,19 @@ void main() {
     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
 
+#if 1
+#ifndef MUL_MAT_ID
+    // XXX TODO this is missing bounds checking against p.M and p.N,
+    // which probably requires spilling to shared memory and doing scalar stores.
+    // But sums[] may not all fit in shared memory...
+    [[unroll]] for (uint i = 0; i < WM; i += 16) {
+        [[unroll]] for (uint j = 0; j < WN; j += 16) {
+            coopMatStore(sums[(i / 16) * (WN / 16) + (j / 16)], data_d, offsets + (dc + j) * p.stride_d + dr + i, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
+        }
+    }
+#endif
+#else
+
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
 
@@ -505,4 +506,5 @@ void main() {
             }
         }
     }
+#endif
 }
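
Note (illustrative, not part of the patch): below is a minimal, self-contained sketch of the GL_KHR_cooperative_matrix sequence the hunks above rely on — load a 16x16 A tile and a 16x16 B tile, perform one 16x16x16 multiply-accumulate, and store the accumulator. The buffer names, std430 bindings, and the 32-wide workgroup are assumptions for the sketch; only the coopmat types, layout constants, and built-ins mirror the shader.

    #version 450 core

    #extension GL_KHR_cooperative_matrix : require
    #extension GL_KHR_memory_scope_semantics : require
    #extension GL_EXT_shader_explicit_arithmetic_types : require
    #extension GL_EXT_shader_16bit_storage : require

    // One workgroup == one subgroup here; 32 is an assumption, real code
    // should match the device's reported subgroup size.
    layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

    // Hypothetical bindings; the real shader's descriptor layout differs.
    layout(binding = 0, std430) readonly  buffer BufA { float16_t data_a[]; };
    layout(binding = 1, std430) readonly  buffer BufB { float16_t data_b[]; };
    layout(binding = 2, std430) writeonly buffer BufD { float data_d[]; };

    void main() {
        // Matrix fragments are distributed across the whole subgroup, so
        // every invocation must execute these calls with identical arguments.
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> matB;
        coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> acc =
            coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);

        // coopMatLoad(mat, buf, element offset, stride in elements, layout):
        // A row-major, B column-major, as in the buf_a/buf_b loads above.
        coopMatLoad(matA, data_a, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
        coopMatLoad(matB, data_b, 0, 16, gl_CooperativeMatrixLayoutColumnMajor);

        // One 16x16x16 tile: acc = matA * matB + acc.
        acc = coopMatMulAdd(matA, matB, acc);

        coopMatStore(acc, data_d, 0, 16, gl_CooperativeMatrixLayoutColumnMajor);
    }

In the patch itself the tiles come from the shared-memory staging buffers buf_a/buf_b (hence the BK+1 element stride, which keeps the existing bank-conflict padding), and the accumulators persist in sums[] across the whole k loop rather than being stored after each tile.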