From ebb1192fd62efaffbb25b9bd3d0eeef2b538e931 Mon Sep 17 00:00:00 2001 From: Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> Date: Sat, 16 Nov 2024 02:57:00 +0530 Subject: [PATCH 01/17] Make updates to fix issues with clang-cl builds while using AVX512 flags (llama/10314) --- src/ggml-cpu/CMakeLists.txt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/ggml-cpu/CMakeLists.txt b/src/ggml-cpu/CMakeLists.txt index 30de6c99a..cef41a074 100644 --- a/src/ggml-cpu/CMakeLists.txt +++ b/src/ggml-cpu/CMakeLists.txt @@ -143,14 +143,23 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW if (GGML_AVX512_VBMI) add_compile_definitions($<$:__AVX512VBMI__>) add_compile_definitions($<$:__AVX512VBMI__>) + if (CMAKE_C_COMPILER_ID STREQUAL "Clang") + list(APPEND ARCH_FLAGS -mavx512vbmi) + endif() endif() if (GGML_AVX512_VNNI) add_compile_definitions($<$:__AVX512VNNI__>) add_compile_definitions($<$:__AVX512VNNI__>) + if (CMAKE_C_COMPILER_ID STREQUAL "Clang") + list(APPEND ARCH_FLAGS -mavx512vnni) + endif() endif() if (GGML_AVX512_BF16) add_compile_definitions($<$:__AVX512BF16__>) add_compile_definitions($<$:__AVX512BF16__>) + if (CMAKE_C_COMPILER_ID STREQUAL "Clang") + list(APPEND ARCH_FLAGS -mavx512bf16) + endif() endif() if (GGML_AMX_TILE) add_compile_definitions($<$:__AMX_TILE__>) From fa0b72b2801453059568d7afe50060c4c3dcb90c Mon Sep 17 00:00:00 2001 From: Dan Johansson Date: Sat, 16 Nov 2024 01:53:37 +0100 Subject: [PATCH 02/17] ggml : optimize Q4_0 into Q4_0_X_Y repack (llama/10324) --- src/ggml-aarch64.c | 57 +++++++++++++++++++++++++-------- src/ggml-cpu/ggml-cpu-aarch64.c | 57 +++++++++++++++++++++++++-------- 2 files changed, 86 insertions(+), 28 deletions(-) diff --git a/src/ggml-aarch64.c b/src/ggml-aarch64.c index 688b17b45..013912051 100644 --- a/src/ggml-aarch64.c +++ b/src/ggml-aarch64.c @@ -8,19 +8,42 @@ #define UNUSED GGML_UNUSED -static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x4 out; for (int i = 0; i < 4; i++) { out.d[i] = in[i].d; } - for (int i = 0; i < QK4_0 * 2; i++) { - int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); + const int end = QK4_0 * 2 / blck_size_interleave; - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + if (blck_size_interleave == 8) { + const uint64_t xor_mask = 0x8888888888888888ULL; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + // Using memcpy to avoid unaligned memory accesses + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + } else if (blck_size_interleave == 4) { + const uint32_t xor_mask = 0x88888888; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint32_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); } return out; @@ -30,19 +53,25 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_in // returns an 
interleaved block_q4_0x8 // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks // first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave -static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x8 out; for (int i = 0; i < 8; i++) { out.d[i] = in[i].d; } - for (int i = 0; i < QK4_0 * 4; i++) { - int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); + const int end = QK4_0 * 4 / blck_size_interleave; + const uint64_t xor_mask = 0x8888888888888888ULL; + + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); } return out; @@ -71,11 +100,11 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds } if (nrows_interleaved == 8) { - *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88); + *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave); out_ptr = (block_q4_0x8 *) out_ptr + 1; } else if (nrows_interleaved == 4) { - *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88); + *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave); out_ptr = (block_q4_0x4 *) out_ptr + 1; } } diff --git a/src/ggml-cpu/ggml-cpu-aarch64.c b/src/ggml-cpu/ggml-cpu-aarch64.c index b753ba767..96a16dfba 100644 --- a/src/ggml-cpu/ggml-cpu-aarch64.c +++ b/src/ggml-cpu/ggml-cpu-aarch64.c @@ -3387,19 +3387,42 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * } // FIXME: this code is duplicated from ggml-aarch64.c -static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x4 out; for (int i = 0; i < 4; i++) { out.d[i] = in[i].d; } - for (int i = 0; i < QK4_0 * 2; i++) { - int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); + const int end = QK4_0 * 2 / blck_size_interleave; - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + if (blck_size_interleave == 8) { + const uint64_t xor_mask = 0x8888888888888888ULL; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + // Using memcpy to avoid unaligned memory accesses + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + } else if (blck_size_interleave == 4) { + const uint32_t xor_mask = 0x88888888; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint32_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], 
&elems, sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); } return out; @@ -3409,19 +3432,25 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_in // returns an interleaved block_q4_0x8 // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks // first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave -static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x8 out; for (int i = 0; i < 8; i++) { out.d[i] = in[i].d; } - for (int i = 0; i < QK4_0 * 4; i++) { - int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); + const int end = QK4_0 * 4 / blck_size_interleave; + const uint64_t xor_mask = 0x8888888888888888ULL; + + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); } return out; @@ -3449,7 +3478,7 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q4_0x4(dst_tmp, interleave_block, 0x88); + *dst++ = make_block_q4_0x4(dst_tmp, interleave_block); } src += nrows_interleaved * nblocks; } @@ -3480,7 +3509,7 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, for (int i = 0; i < nrows_interleaved; i++ ) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q4_0x8(dst_tmp, interleave_block, 0x88); + *dst++ = make_block_q4_0x8(dst_tmp, interleave_block); } src += nrows_interleaved * nblocks; } From c759a72a14cfe518fa2f09cd0cdc544008d9183b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 16 Nov 2024 00:26:57 -0600 Subject: [PATCH 03/17] vulkan: Optimize some mat-vec mul quant shaders (llama/10296) Compute two result elements per workgroup (for Q{4,5}_{0,1}). This reuses the B loads across the rows and also reuses some addressing calculations. This required manually partially unrolling the loop, since the compiler is less willing to unroll outer loops. Add bounds-checking on the last iteration of the loop. I think this was at least partly broken before. Optimize the Q4_K shader to vectorize most loads and reduce the number of bit twiddling instructions. 
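As a rough illustration of the loop structure this describes (not the shader itself), the following minimal C sketch shows the same idea: each load of B is reused across NUM_ROWS partial sums, most of the work runs through a manually unrolled body with no bounds checks, and only the tail iterations pay for the out-of-bounds check on the second element of each pair. NUM_ROWS, UNROLL, mat_vec_rows and the plain-float layout are illustrative assumptions, not symbols from the shader.

    /* Illustrative sketch only; the real code is the GLSL in mul_mat_vec.comp below. */
    #include <stddef.h>

    #define NUM_ROWS 2   /* rows computed per workgroup (assumed value) */
    #define UNROLL   8   /* manual unroll factor (assumed value) */

    static void mat_vec_rows(const float *a, const float *b, float *out,
                             size_t ncols, size_t row_stride)
    {
        float acc[NUM_ROWS] = {0};

        /* columns are consumed two at a time; the last pair may be incomplete */
        const size_t npairs         = (ncols + 1) / 2;
        const size_t complete_pairs = ncols / 2;
        const size_t full_pairs     = complete_pairs - complete_pairs % UNROLL;

        size_t p = 0;
        while (p < full_pairs) {
            /* manually (partially) unrolled hot path: no bounds checks here */
            for (size_t k = 0; k < UNROLL; ++k, ++p) {
                const size_t col = 2 * p;
                const float  b0  = b[col];      /* B loads are shared by all rows */
                const float  b1  = b[col + 1];
                for (size_t n = 0; n < NUM_ROWS; ++n) {
                    acc[n] += a[n * row_stride + col]     * b0;
                    acc[n] += a[n * row_stride + col + 1] * b1;
                }
            }
        }

        /* tail: only here do we check whether the second element of a pair is OOB */
        for (; p < npairs; ++p) {
            const size_t col       = 2 * p;
            const int    second_ok = (col + 1 < ncols);
            const float  b0        = b[col];
            const float  b1        = second_ok ? b[col + 1] : 0.0f;
            for (size_t n = 0; n < NUM_ROWS; ++n) {
                acc[n] += a[n * row_stride + col] * b0;
                if (second_ok) {
                    acc[n] += a[n * row_stride + col + 1] * b1;
                }
            }
        }

        for (size_t n = 0; n < NUM_ROWS; ++n) {
            out[n] = acc[n];
        }
    }

On the GPU this split has to be written out by hand because, as noted above, the compiler is less willing to unroll the outer loop on its own.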
--- src/ggml-vulkan/ggml-vulkan.cpp | 83 +++++----- .../vulkan-shaders/mul_mat_vec.comp | 97 ++++++++--- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 155 ++++++++++-------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 6 +- 4 files changed, 210 insertions(+), 131 deletions(-) diff --git a/src/ggml-vulkan/ggml-vulkan.cpp b/src/ggml-vulkan/ggml-vulkan.cpp index 04d671aeb..19951d9b6 100644 --- a/src/ggml-vulkan/ggml-vulkan.cpp +++ b/src/ggml-vulkan/ggml-vulkan.cpp @@ -1365,47 +1365,48 @@ static void ggml_vk_load_shaders(vk_device& device) { } // mul mat vec - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); 
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + // computing two rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0. + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 
3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, 
sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, 
sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index d3ccba7fc..ddbac5d2c 100644 --- a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -3,54 +3,107 @@ #ifdef FLOAT16 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_EXT_null_initializer : enable #include "mul_mat_vec_base.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (constant_id = 1) const uint NUM_ROWS = 1; -shared FLOAT_TYPE tmp[BLOCK_SIZE]; +uint a_offset, b_offset, d_offset, y_offset; -void main() { - const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; - const uint tid = gl_LocalInvocationID.x; +shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; - // There are not enough cols to use all threads - if (tid >= p.ncols) { - return; +void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) +{ + const uint col = i*BLOCK_SIZE + 2*tid; + const uint iqs = (col%QUANT_K)/QUANT_R; // quant index + const uint iybs = col - col%QUANT_K; // y block start index + + // Check if the second of the pair of elements is OOB, and don't fetch B or + // accumulate it. We still fetch a pair of elements for A, which is fine for + // quantized formats since they'll be within the same block. We should + // probably skip fetching the second element for F16/F32, but as of now we + // still do. + const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); + + FLOAT_TYPE b0 = 0, b1 = 0; + b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]); + if (!OOB) { + b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]); } + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index + + const vec2 v = dequantize(ib, iqs, a_offset); + + // matrix multiplication + temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]); + if (!OOB) { + temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]); + } + } +} - const uint block_size = min(p.ncols, BLOCK_SIZE); +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; - uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; - const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + y_offset = QUANT_R == 1 ? 
1 : QUANT_K/2; - tmp[tid] = FLOAT_TYPE(0.0f); + FLOAT_TYPE temp[NUM_ROWS] = {}; - [[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) { - const uint col = i*block_size + 2*tid; - const uint ib = (row*p.ncols + col)/QUANT_K; // block index - const uint iqs = (col%QUANT_K)/QUANT_R; // quant index - const uint iybs = col - col%QUANT_K; // y block start index + const int unroll_count = 8; - vec2 v = dequantize(ib, iqs, a_offset / QUANT_K); + const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0; + const uint unrolled_iters = num_iters & ~(2*unroll_count - 1); - // matrix multiplication - tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid])); + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i, false); + i += 2; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i, true); + i += 2; } // sum up partial sums and write back result + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[n][tid] = temp[n]; + } barrier(); - [[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) { + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { if (tid < s) { - tmp[tid] += tmp[tid + s]; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[n][tid] += tmpsh[n][tid + s]; + } } barrier(); } if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]); + } + } +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + compute_outputs(first_row, p.stride_d - first_row); } } diff --git a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index d91e00e10..b7c9b722d 100644 --- a/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -1,11 +1,34 @@ #version 450 +#extension GL_EXT_shader_explicit_arithmetic_types : require + #include "mul_mat_vec_base.comp" layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE tmp[32]; +// Declare aliased versions of A and B bindings that can use 16b/32b loads for +// the quantized values, and vec4 loads for B. 
+struct block_q4_K_u32 +{ + f16vec2 d; + uint32_t scales[3*QUANT_K/64/4]; + uint32_t qs[QUANT_K/2/4]; +}; + +struct block_q4_K_u16 +{ + f16vec2 d; + uint16_t scales[3*QUANT_K/64/2]; + uint16_t qs[QUANT_K/2/2]; +}; + +layout (binding = 0) readonly buffer A_u32 {block_q4_K_u32 data_a_u32[];}; +layout (binding = 0) readonly buffer A_u16 {block_q4_K_u16 data_a_u16[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads void main() { const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; @@ -31,79 +54,81 @@ void main() { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp + FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const uint y1_idx = i * QUANT_K + y_offset; const uint y2_idx = y1_idx + 128; - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); - - const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f); - const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f); - const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f); - const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f); - const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2)); - const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2)); - const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); - const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); - -#if K_QUANTS_PER_ITERATION == 2 - const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf); - const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf); - const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4); - const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4); - const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf); - const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf); - const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4); - const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4); - - const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]), q4_0, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), q4_1, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3))); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5, 
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6, FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7))); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]), q4_8, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), q4_9, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11))); - const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7, - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7, - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7, - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7))))))))))))))); - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx])); -#else - const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - - const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5); - const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7); + f16vec2 d = data_a[ib0 + i].d; + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + uint32_t scale0_u32 = data_a_u16[ib0 + i].scales[v_im ]; + uint32_t scale4_u32 = data_a_u16[ib0 + i].scales[v_im + 2]; + uint32_t scale8_u32 = data_a_u16[ib0 + i].scales[v_im + 4]; + uvec4 scale0 = uvec4(unpack8(scale0_u32)); + uvec4 scale4 = uvec4(unpack8(scale4_u32)); + uvec4 scale8 = uvec4(unpack8(scale8_u32)); + + const uint32_t sc0 = ( scale0.x & 0x3f); + const uint32_t sc1 = ( scale0.y & 0x3f); + const uint32_t sc2 = ( scale4.x & 0x3f); + const uint32_t sc3 = ( scale4.y & 0x3f); + const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); + const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); + const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); + const 
uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); + + uint32_t qs0_u32 = data_a_u32[ib0 + i].qs[q_offset / 4]; + uint32_t qs64_u32 = data_a_u32[ib0 + i].qs[q_offset / 4 + 16]; + + uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); + uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); + uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); + uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); + + const uint32_t q4_0 = qs0_lo4.x; + const uint32_t q4_1 = qs0_lo4.y; + const uint32_t q4_2 = qs0_lo4.z; + const uint32_t q4_3 = qs0_lo4.w; + const uint32_t q4_4 = qs0_hi4.x; + const uint32_t q4_5 = qs0_hi4.y; + const uint32_t q4_6 = qs0_hi4.z; + const uint32_t q4_7 = qs0_hi4.w; + const uint32_t q4_8 = qs64_lo4.x; + const uint32_t q4_9 = qs64_lo4.y; + const uint32_t q4_10 = qs64_lo4.z; + const uint32_t q4_11 = qs64_lo4.w; + const uint32_t q4_12 = qs64_hi4.x; + const uint32_t q4_13 = qs64_hi4.y; + const uint32_t q4_14 = qs64_hi4.z; + const uint32_t q4_15 = qs64_hi4.w; + + B_TYPE_VEC4 by10 = data_b_v4[(b_offset + y1_idx) / 4]; + B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8]; + B_TYPE_VEC4 by20 = data_b_v4[(b_offset + y2_idx) / 4]; + B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8]; + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); const FLOAT_TYPE smin = - fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7, - + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7))))))); - - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + - sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin); - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f), - fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx])); -#endif + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + 
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp)); } + tmp[gl_LocalInvocationID.x] = temp; + // sum up partial sums and write back result barrier(); [[unroll]] for (uint s = 16; s > 0; s >>= 1) { diff --git a/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 5c84f473f..fe3e4cb39 100644 --- a/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -317,10 +317,10 @@ void process_shaders() { std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; - string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); // Dequant shaders if (tname != "f16") { From 92bc9e247ea8981ff74dbcf31df547ecb657b6f5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Nov 2024 17:58:56 +0200 Subject: [PATCH 04/17] llamafile : fix include path (llama/0) ggml-ci --- src/ggml-cpu/llamafile/sgemm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ggml-cpu/llamafile/sgemm.cpp b/src/ggml-cpu/llamafile/sgemm.cpp index da4146ec4..b2ce2e664 100644 --- a/src/ggml-cpu/llamafile/sgemm.cpp +++ b/src/ggml-cpu/llamafile/sgemm.cpp @@ -50,7 +50,8 @@ #include "sgemm.h" #include "ggml-impl.h" -#include "ggml-cpu-impl.h" +// hack until moved into the CPU backend +#include "../ggml-cpu-impl.h" #include "ggml-quants.h" #ifdef _MSC_VER From 39de6533f6ef694b55004e6fad6fc7508c7e7094 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Nov 2024 21:32:41 +0200 Subject: [PATCH 05/17] ggml : fix compile warnings (llama/0) ggml-ci --- src/ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ggml.c b/src/ggml.c index 4a478fcaa..ee72a173e 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -5501,7 +5501,7 @@ static void ggml_compute_backward( fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", __func__, ggml_unary_op_name(ggml_get_unary_op(tensor))); GGML_ABORT("fatal error"); - } break; + } //break; } } break; case GGML_OP_CROSS_ENTROPY_LOSS: { @@ -5517,7 +5517,7 @@ static void ggml_compute_backward( default: { fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); 
GGML_ABORT("fatal error"); - } break; + } //break; } GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0])); From 04260625163b26b45605e18310a8650f083da4dd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Nov 2024 21:38:01 +0200 Subject: [PATCH 06/17] ggml : adapt AMX to tensor->grad removal (llama/0) ggml-ci --- src/ggml-amx/ggml-amx.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/ggml-amx/ggml-amx.cpp b/src/ggml-amx/ggml-amx.cpp index 37da98539..8568e7965 100644 --- a/src/ggml-amx/ggml-amx.cpp +++ b/src/ggml-amx/ggml-amx.cpp @@ -317,8 +317,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st const enum ggml_type type = src0->type; const int64_t ne0 = op->ne[0]; - bool is_training = src0->grad || src1->grad; - // amx kernels enables for Q4_0, Q4_1, Q8_0, F16 // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256 bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16); @@ -326,7 +324,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st bool can_use_amx = is_contiguous_2d(src0) && // src0 must be contiguous is_contiguous_2d(src1) && // src1 must be contiguous - !is_training && // inference only src1->type == GGML_TYPE_F32 && // src1 must be float32 has_amx_kernels && // with amx kernel impls ne0 % (TILE_N * 2) == 0; // out_features is 32x From 1bf5d5f58b42343b653dc0485ec22e56fa2b1c94 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 Nov 2024 23:40:39 +0200 Subject: [PATCH 07/17] ggml : inttypes.h -> cinttypes (llama/0) ggml-ci --- src/ggml-opt.cpp | 2 +- tests/test-opt.cpp | 2 +- tests/test-quantize-perf.cpp | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ggml-opt.cpp b/src/ggml-opt.cpp index 808aa0d02..040205a31 100644 --- a/src/ggml-opt.cpp +++ b/src/ggml-opt.cpp @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp index 4abe85c74..f90c92b4b 100644 --- a/tests/test-opt.cpp +++ b/tests/test-opt.cpp @@ -5,7 +5,7 @@ #include "ggml-opt.h" #include -#include +#include #include #include #include diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp index ac0d12714..288288493 100644 --- a/tests/test-quantize-perf.cpp +++ b/tests/test-quantize-perf.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include From 23d71e191075f6c621ae6ebd218a2e2383b8b9ec Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Sun, 17 Nov 2024 07:31:17 +0100 Subject: [PATCH 08/17] ggml : fix possible buffer use after free in sched reserve (llama/9930) --- src/ggml-backend.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ggml-backend.cpp b/src/ggml-backend.cpp index 634fe38ee..9a6010d36 100644 --- a/src/ggml-backend.cpp +++ b/src/ggml-backend.cpp @@ -1538,12 +1538,13 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * ggml_backend_sched_split_graph(sched, measure_graph); + ggml_backend_sched_synchronize(sched); + if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { return false; } ggml_backend_sched_reset(sched); - ggml_backend_sched_synchronize(sched); return true; } From 35b66a3d8956d1f9dbef35106b09f19d98bb9454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 17 Nov 2024 09:06:34 +0100 Subject: [PATCH 09/17] CMake: default to -arch=native for CUDA build (llama/10320) --- 
src/ggml-cuda/CMakeLists.txt | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/ggml-cuda/CMakeLists.txt b/src/ggml-cuda/CMakeLists.txt index 40ed2bdf3..860552f3a 100644 --- a/src/ggml-cuda/CMakeLists.txt +++ b/src/ggml-cuda/CMakeLists.txt @@ -6,15 +6,18 @@ if (CUDAToolkit_FOUND) message(STATUS "CUDA Toolkit found") if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - # 52 == lowest CUDA 12 standard - # 60 == FP16 CUDA intrinsics - # 61 == integer CUDA intrinsics - # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) + # native == GPUs available at build time + # 52 == Maxwell, lowest CUDA 12 standard + # 60 == P100, FP16 CUDA intrinsics + # 61 == Pascal, __dp4a instruction (per-byte integer dot product) + # 70 == V100, FP16 tensor cores + # 75 == Turing, int6 tensor cores + if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6") + set(CMAKE_CUDA_ARCHITECTURES "native") + elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") else() set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") - #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work endif() endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") From 1f57528b45f440fd8a5264ff7e23fee86a19402f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 17 Nov 2024 09:09:55 +0100 Subject: [PATCH 10/17] CUDA: remove DMMV, consolidate F16 mult mat vec (llama/10318) --- CMakeLists.txt | 5 - src/ggml-cuda/CMakeLists.txt | 13 -- src/ggml-cuda/ggml-cuda.cu | 204 +++----------------------------- src/ggml-cuda/mmv.cu | 223 +++++++++++++++++++++++++++++++++++ src/ggml-cuda/mmv.cuh | 12 ++ src/ggml-hip/CMakeLists.txt | 7 -- src/ggml-musa/CMakeLists.txt | 11 -- 7 files changed, 254 insertions(+), 221 deletions(-) create mode 100644 src/ggml-cuda/mmv.cu create mode 100644 src/ggml-cuda/mmv.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index fd9499826..a82818d60 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -128,14 +128,9 @@ option(GGML_LLAMAFILE "ggml: use LLAMAFILE" option(GGML_CUDA "ggml: use CUDA" OFF) option(GGML_MUSA "ggml: use MUSA" OFF) -option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF) option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF) option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF) -set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels") -set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels") option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) -set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING - "ggml: iters./thread per block for Q2_K/Q6_K") set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "ggml: max. batch size for using peer access") option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) diff --git a/src/ggml-cuda/CMakeLists.txt b/src/ggml-cuda/CMakeLists.txt index 860552f3a..3dde0f366 100644 --- a/src/ggml-cuda/CMakeLists.txt +++ b/src/ggml-cuda/CMakeLists.txt @@ -54,21 +54,12 @@ if (CUDAToolkit_FOUND) target_link_libraries(ggml-cuda PRIVATE ggml-base) target_include_directories(ggml-cuda PRIVATE . ..) 
- # TODO: change the definitions to this target only - - add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) if (GGML_CUDA_GRAPHS) add_compile_definitions(GGML_CUDA_USE_GRAPHS) endif() - if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() @@ -81,10 +72,6 @@ if (CUDAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_VMM) endif() - if (DEFINED GGML_CUDA_DMMV_Y) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility - endif() - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_F16) endif() diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu index 07f043328..ef56e944d 100644 --- a/src/ggml-cuda/ggml-cuda.cu +++ b/src/ggml-cuda/ggml-cuda.cu @@ -16,11 +16,11 @@ #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" -#include "ggml-cuda/dmmv.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" +#include "ggml-cuda/mmv.cuh" #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" @@ -1020,114 +1020,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)( #define MUL_MAT_SRC1_COL_STRIDE 128 -static __global__ void mul_mat_p021_f16_f32( - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / (nchannels_y / nchannels_x); - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - // x is transposed and permuted - const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; - const float xi = __half2float(x[ix]); - - const int row_y = col_x; - - // y is not transposed but permuted - const int iy = channel*nrows_y + row_y; - - tmp += xi * y[iy]; - } - - // dst is not transposed and not permuted - const int idst = channel*nrows_dst + row_dst; - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / channel_x_divisor; - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - const int idst = channel*nrows_dst + row_dst; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int 
col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - const int row_y = col_x; - - const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; - const int iy = channel*nrows_y + row_y; - - const float xi = __half2float(x[ix]); - - tmp += xi * y[iy]; - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static void ggml_mul_mat_p021_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, - const int nchannels_x, const int nchannels_y, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); -} - -static void ggml_mul_mat_vec_nc_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, - const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); -} - static cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { @@ -1654,58 +1546,6 @@ static void ggml_cuda_op_mul_mat( } } -static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation - GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t ne12 = src1->ne[2]; - - cudaStream_t main_stream = ctx.stream(); - - void * src0_ddq = src0->data; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; - - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); -} - -static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(!ggml_is_permuted(src0)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t nb01 = src0->nb[1]; - const int64_t nb02 = src0->nb[2]; - - const int64_t ne12 = src1->ne[2]; - - cudaStream_t main_stream = ctx.stream(); - - void * src0_ddq = src0->data; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; - - const int64_t row_stride_x = nb01 / sizeof(half); - const int64_t channel_stride_x = nb02 / sizeof(half); - - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); -} - static __global__ void 
k_compute_batched_ptrs( const half * src0_as_f16, const half * src1_as_f16, char * dst, const void ** ptrs_src, void ** ptrs_dst, @@ -1879,21 +1719,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); - bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type) + bool use_mul_mat_vec = src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - bool use_mul_mat_q = ggml_is_quantized(src0->type) + bool use_mul_mat_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - // if mmvq is available it's a better choice than dmmv: -#ifndef GGML_CUDA_FORCE_DMMV - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; -#endif // GGML_CUDA_FORCE_DMMV - - bool any_gpus_with_slow_fp16 = false; + bool any_gpus_with_slow_fp16 = false; + bool any_gpus_without_fp16_mma = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; @@ -1904,14 +1740,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - const int cc = ggml_cuda_info().devices[id].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); } } else { - const int cc = ggml_cuda_info().devices[ctx.device].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); } // debug helpers @@ -1922,18 +1760,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // FP32 precision KQ single-batch for batch size 1 without FlashAttention - ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); - } else if (!split && 
any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { - // FP32 precision KQV single-batch for batch size 1 without FlashAttention - ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); + if (!split && src0->type == GGML_TYPE_F16 && src1->ne[1] == 1 && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); - } else if (use_dequantize_mul_mat_vec) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); + } else if (use_mul_mat_vec) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { diff --git a/src/ggml-cuda/mmv.cu b/src/ggml-cuda/mmv.cu new file mode 100644 index 000000000..cfe91f428 --- /dev/null +++ b/src/ggml-cuda/mmv.cu @@ -0,0 +1,223 @@ +#include "common.cuh" +#include "mmv.cuh" + +template +static __global__ void mul_mat_vec( + const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { + const int64_t row = blockIdx.x; + const int64_t channel = blockIdx.z; + const int tid = threadIdx.x; + + x += (channel/channel_ratio)*stride_channel_x + row*stride_row; + y += channel *stride_channel_y; + dst += channel *stride_channel_dst; + + const half2 * x2 = (const half2 *) x; + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + float * buf_iw = (float *) data_mmv; + + if (block_size > WARP_SIZE) { + if (tid < WARP_SIZE) { + buf_iw[tid] = 0.0f; + } + __syncthreads(); + } + + float sumf; + + if (std::is_same::value) { + sumf = 0.0f; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = __half22float2(x2[col2]); + const float2 tmpy = y2[col2]; + sumf += tmpx.x * tmpy.x; + sumf += tmpx.y * tmpy.y; + } + } else { +#ifdef FP16_AVAILABLE + half2 sumh2 = make_half2(0.0f, 0.0f); + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmp = y2[col2]; + sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); + } + + sumf = __low2float(sumh2) + __high2float(sumh2); +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE + } + + sumf = warp_reduce_sum(sumf); + + if (block_size > WARP_SIZE) { + buf_iw[tid/WARP_SIZE] = sumf; + __syncthreads(); + if (tid > WARP_SIZE) { + return; + } + sumf = buf_iw[tid]; + sumf = warp_reduce_sum(sumf); + } + + if (tid != 0) { + return; + } + + dst[row] = sumf; +} + +template +static void launch_mul_mat_vec_cuda( + const half * x, const float * y, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + cudaStream_t stream) { + GGML_ASSERT(ncols % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(nchannels_y % nchannels_x == 0); + const int64_t 
channel_ratio = nchannels_y / nchannels_x; + + int64_t block_size_best = WARP_SIZE; + int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) { + const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); + if (niter < niter_best) { + niter_best = niter; + block_size_best = block_size; + } + } + + const int smem = WARP_SIZE*sizeof(float); + const dim3 block_nums(nrows, 1, nchannels_y); + const dim3 block_dims(block_size_best, 1, 1); + switch (block_size_best) { + case 32: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 64: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 96: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 128: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 160: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 192: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 224: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 256: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +static void mul_mat_vec_cuda( + const half * x, const float * y, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + enum ggml_prec prec, cudaStream_t stream) { + switch (prec) { + case GGML_PREC_DEFAULT: { + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + stride_channel_x, stride_channel_y, stride_channel_dst, stream); + } break; + case GGML_PREC_F32: { + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + stride_channel_x, stride_channel_y, stride_channel_dst, stream); + } break; + } +} + +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + GGML_ASSERT(src1->ne[1] == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? 
ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + const half * src0_d = (const half *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + GGML_ASSERT(dst->ne[2] == ne12); + + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(src1->ne[3] == 1); + GGML_ASSERT( dst->ne[3] == 1); + + const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type); + const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type); + const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type); + const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type); + + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); +} + +void ggml_cuda_op_mul_mat_vec( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + GGML_ASSERT(src1_ncols == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + + // ggml_cuda_op provides single, contiguous matrices + const int64_t stride_row = ne00; + const int64_t nchannels_x = 1; + const int64_t nchannels_y = 1; + const int64_t channel_stride_x = 0; + const int64_t channel_stride_y = 0; + const int64_t channel_stride_dst = 0; + + mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + + GGML_UNUSED(ctx); + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_ncols); + GGML_UNUSED(src1_padded_row_size); +} diff --git a/src/ggml-cuda/mmv.cuh b/src/ggml-cuda/mmv.cuh new file mode 100644 index 000000000..78a1cd4a6 --- /dev/null +++ b/src/ggml-cuda/mmv.cuh @@ -0,0 +1,12 @@ +#include "common.cuh" + +// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available +#define MMV_MAX_ROWS 512 + +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +void ggml_cuda_op_mul_mat_vec( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); diff --git a/src/ggml-hip/CMakeLists.txt b/src/ggml-hip/CMakeLists.txt index 5ed186ded..fccf8eb84 100644 --- a/src/ggml-hip/CMakeLists.txt +++ b/src/ggml-hip/CMakeLists.txt @@ -75,18 +75,11 @@ target_include_directories(ggml-hip PRIVATE . ..) 
target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) add_compile_definitions(GGML_USE_HIP) -add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) -add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) -add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) if (GGML_HIP_UMA) add_compile_definitions(GGML_HIP_UMA) endif() -if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) -endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() diff --git a/src/ggml-musa/CMakeLists.txt b/src/ggml-musa/CMakeLists.txt index 8edc75cc5..f3c013692 100644 --- a/src/ggml-musa/CMakeLists.txt +++ b/src/ggml-musa/CMakeLists.txt @@ -58,19 +58,12 @@ if (MUSAToolkit_FOUND) target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) add_compile_definitions(GGML_USE_MUSA) - add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) if (GGML_CUDA_GRAPHS) add_compile_definitions(GGML_CUDA_USE_GRAPHS) endif() - if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() @@ -83,10 +76,6 @@ if (MUSAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_VMM) endif() - if (DEFINED GGML_CUDA_DMMV_Y) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility - endif() - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_F16) endif() From 492779b60a78a2d9265f0ea42fc75747aab1b2e2 Mon Sep 17 00:00:00 2001 From: FirstTimeEZ <179362031+FirstTimeEZ@users.noreply.github.com> Date: Sun, 17 Nov 2024 21:39:22 +1300 Subject: [PATCH 11/17] ggml : fix undefined reference to 'getcpu' (llama/10354) https://github.com/ggerganov/llama.cpp/issues/10352 --- src/ggml-cpu/ggml-cpu.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ggml-cpu/ggml-cpu.c b/src/ggml-cpu/ggml-cpu.c index df6487929..0d23669c2 100644 --- a/src/ggml-cpu/ggml-cpu.c +++ b/src/ggml-cpu/ggml-cpu.c @@ -2369,7 +2369,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) { // figure out which node we're on uint current_cpu; int getcpu_ret = 0; -#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__) +#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__) getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node); #else // old glibc doesn't have a wrapper for this call. 
Fall back on direct syscall From 7bef2570e084a74bf815ffac5d40cd9f0d18dfae Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Nov 2024 11:23:01 +0200 Subject: [PATCH 12/17] metal : refactor kernel args into structs (llama/10238) * metal : add kernel arg structs (wip) * metal : fattn args ggml-ci * metal : cont + avoid potential int overflow [no ci] * metal : mul mat struct (wip) * cont : mul mat vec * cont : pass by reference * cont : args is first argument * cont : use char ptr * cont : shmem style * cont : thread counters style * cont : mul mm id ggml-ci * cont : int safety + register optimizations ggml-ci * metal : GGML_OP_CONCAT ggml-ci * metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV * metal : GGML_OP_REPEAT * metal : GGML_OP_CPY * metal : GGML_OP_RMS_NORM * metal : GGML_OP_NORM * metal : add TODOs for rest of ops * ggml : add ggml-metal-impl.h ggml-ci --- src/ggml-metal/CMakeLists.txt | 18 +- src/ggml-metal/ggml-metal-impl.h | 249 +++ src/ggml-metal/ggml-metal.m | 681 +++--- src/ggml-metal/ggml-metal.metal | 3592 +++++++++++------------------- 4 files changed, 1945 insertions(+), 2595 deletions(-) create mode 100644 src/ggml-metal/ggml-metal-impl.h diff --git a/src/ggml-metal/CMakeLists.txt b/src/ggml-metal/CMakeLists.txt index e0992c744..b237d79f4 100644 --- a/src/ggml-metal/CMakeLists.txt +++ b/src/ggml-metal/CMakeLists.txt @@ -25,9 +25,10 @@ if (GGML_METAL_USE_BF16) add_compile_definitions(GGML_METAL_USE_BF16) endif() -# copy ggml-common.h and ggml-metal.metal to bin directory -configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) -configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +# copy metal files to bin directory +configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) +configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) @@ -36,24 +37,27 @@ if (GGML_METAL_EMBED_LIBRARY) set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( OUTPUT ${METALLIB_EMBED_ASM} COMMAND echo "Embedding Metal library" - COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED} + COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP} + COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED} COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM} COMMAND echo ".globl 
_ggml_metallib_start" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM} COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM} - DEPENDS ggml-metal.metal ../ggml-common.h + DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h COMMENT "Generate assembly for embedded Metal library" ) diff --git a/src/ggml-metal/ggml-metal-impl.h b/src/ggml-metal/ggml-metal-impl.h new file mode 100644 index 000000000..53c135496 --- /dev/null +++ b/src/ggml-metal/ggml-metal-impl.h @@ -0,0 +1,249 @@ +#ifndef GGML_METAL_IMPL +#define GGML_METAL_IMPL + +// kernel argument structs +// +// - element counters (e.g. ne00) typically use int32_t to reduce register usage +// however, be careful from int overflows when using those in the kernel implementation +// +// - strides (e.g. nb00) use uint64_t + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_concat; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_repeat; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cpy; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; +} ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb_12_1; + uint64_t nb_12_2; + uint64_t nb_12_3; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + 
uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} ggml_metal_kargs_mul_mm_id; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_norm; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_rms_norm; + +#endif // GGML_METAL_IMPL diff --git a/src/ggml-metal/ggml-metal.m b/src/ggml-metal/ggml-metal.m index 95b21fbf9..58fee4bfd 100644 --- a/src/ggml-metal/ggml-metal.m +++ b/src/ggml-metal/ggml-metal.m @@ -2,6 +2,7 @@ #import "ggml-impl.h" #import "ggml-backend-impl.h" +#import "ggml-metal-impl.h" #import @@ -1193,35 +1194,39 @@ static void ggml_metal_encode_node( const int32_t dim = ((const int32_t *) dst->op_params)[0]; + ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 
length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN(1024, ne0); @@ -1239,8 +1244,6 @@ static void ggml_metal_encode_node( bool bcast_row = false; - int64_t nb = ne00; // used by the "row" kernels - id pipeline = nil; if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { @@ -1249,7 +1252,6 @@ static void ggml_metal_encode_node( // src1 is a row GGML_ASSERT(ne11 == 1); - nb = ne00 / 4; switch (dst->op) { case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; @@ -1269,36 +1271,39 @@ static void ggml_metal_encode_node( } } + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb 
length:sizeof(nb) atIndex:28]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (bcast_row) { const int64_t n = ggml_nelements(dst)/4; @@ -1322,25 +1327,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); } + ggml_metal_kargs_repeat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); @@ -1369,25 +1378,29 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder 
setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -1396,35 +1409,39 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -1465,10 +1482,10 @@ static void ggml_metal_encode_node( memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst 
atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; const int64_t n = ggml_nelements(dst); @@ -1640,6 +1657,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1715,6 +1733,8 @@ static void ggml_metal_encode_node( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO: add ggml_metal_kargs struct + // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { @@ -1731,6 +1751,7 @@ static void ggml_metal_encode_node( [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -1747,6 +1768,7 @@ static void ggml_metal_encode_node( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1771,6 +1793,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1841,6 +1864,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1959,24 +1983,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("MUL MAT-MAT not implemented"); } + ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12]; 
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:15]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:16]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -2154,28 +2183,32 @@ static void ggml_metal_encode_node( } }; + ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:19]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:20]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || @@ -2288,27 +2321,30 @@ static void ggml_metal_encode_node( default: GGML_ABORT("MUL_MAT_ID not implemented"); } + ggml_metal_kargs_mul_mm_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - 
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; @@ -2467,30 +2503,34 @@ static void ggml_metal_encode_node( GGML_ASSERT(ne00 >= nth0*nth1); } + ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; const int tgz = dst_rows; @@ -2563,6 +2603,7 @@ static void ggml_metal_encode_node( 
default: GGML_ABORT("not implemented"); } + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2586,20 +2627,28 @@ static void ggml_metal_encode_node( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + int nth = 32; // SIMD width - while (nth < ne00/4 && nth < 1024) { + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { nth *= 2; } - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -2624,6 +2673,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2641,22 +2691,35 @@ static void ggml_metal_encode_node( } break; case GGML_OP_NORM: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ggml_is_contiguous_1(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = MIN(256, ne00); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -2706,40 +2769,44 @@ static void ggml_metal_encode_node( }; } + ggml_metal_kargs_rope args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ 
freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -2796,6 +2863,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2836,6 +2904,7 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2870,6 +2939,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2906,6 +2976,7 @@ static 
void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; @@ -2927,6 +2998,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2965,6 +3037,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2983,6 +3056,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3224,37 +3298,41 @@ static void ggml_metal_encode_node( } } + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.nb_12_1 =*/ nb11, + /*.nb_12_2 =*/ nb12, + /*.nb_12_3 =*/ nb13, + /*.nb31 =*/ nb31, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19]; - [encoder setBytes:&scale 
length:sizeof( float) atIndex:20]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:21]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:22]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:23]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24]; - [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; if (!use_vec_kernel) { // half8x8 kernel @@ -3389,25 +3467,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -3452,6 +3534,7 @@ static void ggml_metal_encode_node( const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; diff --git a/src/ggml-metal/ggml-metal.metal b/src/ggml-metal/ggml-metal.metal index 8c7fcb113..86fdf1c18 100644 --- a/src/ggml-metal/ggml-metal.metal +++ b/src/ggml-metal/ggml-metal.metal @@ -6,6 +6,7 @@ __embed_ggml-common.h__ // TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift #include "../ggml-common.h" #endif +#include "ggml-metal-impl.h" #include @@ -497,240 +498,131 @@ enum ggml_sort_order { // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient kernel void kernel_add( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - 
constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_sub( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + 
const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_mul( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_div( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - 
constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); } } template kernel void kernel_repeat( + constant ggml_metal_kargs_repeat & args, device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; - const int64_t i03 = i3 % ne03; - const int64_t i02 = i2 % ne02; - const int64_t i01 = i1 % ne01; + const int i03 = i3%args.ne03; + const int i02 = i2%args.ne02; + const int i01 = i1%args.ne01; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + 
i01*args.nb01; + device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i00 = i0 % ne00; - *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i00 = i0%args.ne00; + *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); } } @@ -744,38 +636,42 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] + src1[tpig % nb]; } kernel void kernel_sub_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] - src1[tpig % nb]; } kernel void kernel_mul_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] * src1[tpig % nb]; } kernel void kernel_div_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] / src1[tpig % nb]; } @@ -1345,102 +1241,112 @@ kernel void kernel_ssm_scan_f32( } kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; + constant ggml_metal_kargs_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float4 sumf4(0.0f); + + float sumf = 0.0f; + + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf4 += x[i00]; } - // reduce + sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = simd_sum(sumf); + threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; } - const float mean = sum[0] / ne00; - // recenter and VARIANCE threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 
= tpitg; i00 < ne00; i00 += ntg) { + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + + sumf = 0.0f; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; + sumf += dot(y[i00], y[i00]); } + sumf = simd_sum(sumf); - // reduce threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; } - const float variance = sum[0] / ne00; - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float variance = sumf/args.ne00; + + const float scale = 1.0f/sqrt(variance + args.eps); + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = y[i00] * scale; } } kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + constant ggml_metal_kargs_rms_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); - float4 sumf = 0; - float all_sum = 0; + float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } + sumf = simd_sum(sumf); - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - if (tiisg == 0) { - buf[sgitg] = all_sum; - } + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - all_sum = buf[tiisg]; - all_sum = simd_sum(all_sum); - } + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); - const float mean = all_sum/ne00; - const float scale = 1.0f/sqrt(mean + eps); + const float mean = sumf/args.ne00; + const float scale = 1.0f/sqrt(mean + args.eps); - device float4 * y = (device float4 *) (dst + tgpig*ne00); - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = x[i00] * scale; } } @@ -1628,31 +1534,17 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // quantizations where the block size is 32. 
It also does not // guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. -template +template void mul_vec_q_n_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - const int nb = ne00/QK4_0; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const int nb = args.ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -1660,19 +1552,19 @@ void mul_vec_q_n_f32_impl( const int first_row = (r0 * nsg + sgitg) * nr; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows device const block_q_type * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -1680,10 +1572,10 @@ void mul_vec_q_n_f32_impl( float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; - device const float * yb = y + ix * QK4_0 + il; + device const float * yb = y + ix*QK4_0 + il; // each thread in a SIMD group deals with half a block. 
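    // work distribution sketch (derived from the indices above, assuming N_SIMDWIDTH == 32):
    //   - QK4_0 == 32 quants per block, i.e. 16 bytes of packed nibbles per block
    //   - ix = tiisg/2 pairs two adjacent lanes on the same block
    //   - il = (tiisg%2)*8 points each lane of the pair at a different half of that block
    //   - the loop below strides by nw/2 blocks, so one simdgroup sweeps 16 blocks per iteration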
for (int ib = ix; ib < nb; ib += nw/2) { @@ -1708,324 +1600,216 @@ void mul_vec_q_n_f32_impl( yb += QK4_0 * 16; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; } } } kernel void kernel_mul_mv_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint 
tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } - #define NB_Q8_0 8 +template void kernel_mul_mv_q8_0_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; - const int nb = ne00/QK8_0; + const int nb = args.ne00/QK8_0; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; + const int first_row = (r0*nsg + sgitg)*nr; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows device const block_q8_0 * ax[nr]; for (int row = 0; 
row < nr; ++row) { - const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } float yl[NB_Q8_0]; - float sumf[nr]={0.f}; + float sumf[nr] = { 0.f }; - const int ix = tiisg/4; - const int il = tiisg%4; + const short ix = tiisg/4; + const short il = tiisg%4; - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; + device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; // each thread in a SIMD group deals with NB_Q8_0 quants at a time for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { + for (short i = 0; i < NB_Q8_0; ++i) { yl[i] = yb[i]; } for (int row = 0; row < nr; row++) { - device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il; + device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { + for (short iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } sumf[row] += sumq*ax[row][ib].d; } - yb += NB_Q8_0 * nw; + yb += nw*NB_Q8_0; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; } } } [[host_name("kernel_mul_mv_q8_0_f32")]] kernel void kernel_mul_mv_q8_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 -template +template void kernel_mul_mv_impl( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_MV_T_T; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + 
uint3 tgpig, + ushort tiisg) { + const int r0 = tgpig.x; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); - if (ne00 < 128) { + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + + if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; - if (r1 >= ne11) { + if (r1 >= args.ne11) { break; } - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { + for (int i = tiisg; i < args.ne00; i += 32) { sumf += (T0) x[i] * (T1) y[i]; } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } else { device const T04 * x4 = (device const T04 *) x; for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; - if (r1 >= ne11) { + if (r1 >= args.ne11) { break; } - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], (float4) y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -2033,51 +1817,17 @@ void kernel_mul_mv_impl( template kernel void kernel_mul_mv( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_impl( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_impl( + args, src0, src1, dst, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - nb03, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - nb13, - ne0, - ne1, - r2, - r3, tgpig, tiisg); } @@ -2094,65 +1844,50 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv< template kernel void kernel_mul_mv_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant 
uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { + if (args.ne00 < 128) { + for (int i = tiisg; i < args.ne00; i += 32) { sumf += (float) x[i] * (float) y[i]; } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } else { device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst_f32[r0] = all_sum; } } } @@ -2167,54 +1902,39 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne // Assumes row size (ne00) is a multiple of 4 template kernel void kernel_mul_mv_l4( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { - const int nrows = ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; + const int 
nrows = args.ne11; + const int r0 = tgpig.x; + const int im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + for (int r1 = 0; r1 < nrows; ++r1) { - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -2234,7 +1954,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale, thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; @@ -2266,65 +1986,41 @@ static void rope_yarn_corr_dims( template kernel void kernel_rope_norm( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); - device const int32_t * pos = src1; + device const int32_t * pos = 
(device const int32_t *) src1; const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; + const float inv_ndims = -1.f/args.n_dims; float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); const float x0 = src[0]; const float x1 = src[1]; @@ -2332,8 +2028,8 @@ kernel void kernel_rope_norm( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_data[0] = src[0]; dst_data[1] = src[1]; @@ -2343,74 +2039,50 @@ kernel void kernel_rope_norm( template kernel void kernel_rope_neox( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, 
beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); - device const int32_t * pos = src1; + device const int32_t * pos = (device const int32_t *) src1; const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; + const float inv_ndims = -1.f/args.n_dims; float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); const float x0 = src[0]; - const float x1 = src[n_dims/2]; + const float x1 = src[args.n_dims/2]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_data[0] = src[0]; dst_data[1] = src[1]; @@ -2808,37 +2480,17 @@ template< short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int32_t & ne01, - constant int32_t & ne02, - constant int32_t & ne03, - constant uint32_t & nb01, - constant uint32_t & nb02, - constant uint32_t & nb03, - constant int32_t & ne11, - constant int32_t & ne_12_2, // assume K and V are same shape - constant int32_t & ne_12_3, - constant uint32_t & nb_12_1, - constant uint32_t & nb_12_2, - constant uint32_t & nb_12_3, - constant uint32_t & nb31, - constant int32_t & ne1, - constant int32_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint16_t & n_head_log2, - constant float & logit_softcap, - threadgroup half * shared [[threadgroup(0)]], - ushort3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const 
char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -2854,27 +2506,27 @@ kernel void kernel_flash_attn_ext( const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short T = D + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix - threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o8x8_t lo[D8]; // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); for (short i = tiisg; i < D4; i += NW) { - if (iq1 + j < ne01) { + if (iq1 + j < args.ne01) { sq4[j*D4 + i] = (q4_t) q4[i]; } else { sq4[j*D4 + i] = (q4_t) 0.0f; @@ -2907,11 +2559,11 @@ kernel void kernel_flash_attn_ext( const short ty = tiisg/4; // broadcast kv - //const short rk2 = ne02/ne12; - //const short rk3 = ne03/ne13; + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short ikv2 = iq2/(ne02/ne_12_2); - const short ikv3 = iq3/(ne03/ne_12_3); + const 
short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); // load the queries from shared memory into local memory q8x8_t mq[D8]; @@ -2925,20 +2577,20 @@ kernel void kernel_flash_attn_ext( half slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < n_head_log2 ? m0 : m1; - const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } @@ -2949,7 +2601,7 @@ kernel void kernel_flash_attn_ext( // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31); + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); const half m = pm[ic + tiisg]; @@ -2972,18 +2624,18 @@ kernel void kernel_flash_attn_ext( // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D8) for (short i = 0; i < D8; ++i) { k8x8_t mk; - simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } } else { for (short ii = 0; ii < D16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); if (D16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -3042,10 +2694,10 @@ kernel void kernel_flash_attn_ext( const half m = M[j]; // scale and apply the logitcap / mask - half s = ss[j*TS + tiisg]*scale; + half s = ss[j*TS + tiisg]*args.scale; - if (logit_softcap != 0.0f) { - s = logit_softcap*precise::tanh(s); + if (args.logit_softcap != 0.0f) { + s = args.logit_softcap*precise::tanh(s); } // mqk = mqk + mask*slope @@ -3087,18 +2739,18 @@ kernel void kernel_flash_attn_ext( if (is_same::value) { // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D8) for (short i = 0; i < D8; ++i) { v8x8_t mv; - simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 + simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); } } else { for (short ii = 0; ii < 
D16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); if (D16%4 == 0) { // no need for bound checks @@ -3227,11 +2879,11 @@ kernel void kernel_flash_attn_ext( // final rescale with 1/S and store to global memory if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3323,38 +2975,17 @@ template< short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int32_t & ne01, - constant int32_t & ne02, - constant int32_t & ne03, - constant uint32_t & nb01, - constant uint32_t & nb02, - constant uint32_t & nb03, - constant int32_t & ne11, - constant int32_t & ne_12_2, // assume K and V are same shape - constant int32_t & ne_12_3, - constant uint32_t & nb_12_1, - constant uint32_t & nb_12_2, - constant uint32_t & nb_12_3, - constant uint32_t & nb31, - constant int32_t & ne1, - constant int32_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint16_t & n_head_log2, - constant float & logit_softcap, - threadgroup half * shared [[threadgroup(0)]], - ushort3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -3369,22 +3000,22 @@ kernel void kernel_flash_attn_ext_vec( const short T = D + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask - threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) 
(shmem_f16 + 0*D); // same as above but in q4_t + threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask + threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o4x4_t lo[D16/NL]; // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); for (short i = tiisg; i < D4; i += NW) { - if (iq1 < ne01) { + if (iq1 < args.ne01) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; @@ -3412,11 +3043,11 @@ kernel void kernel_flash_attn_ext_vec( const short ty = tiisg/NL; // broadcast kv - //const short rk2 = ne02/ne12; - //const short rk3 = ne03/ne13; + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short ikv2 = iq2/(ne02/ne_12_2); - const short ikv3 = iq3/(ne03/ne_12_3); + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); // load the queries from shared memory into local memory q4x4_t mq[D16/NL]; @@ -3429,25 +3060,25 @@ kernel void kernel_flash_attn_ext_vec( const bool has_mask = mask != q; // pointer to the mask - device const half * pm = (device const half *) (mask + iq1*nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31); half slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < n_head_log2 ? m0 : m1; - const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? 
h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } @@ -3461,7 +3092,7 @@ kernel void kernel_flash_attn_ext_vec( for (short cc = 0; cc < C/4; ++cc) { qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; - device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D16/NL) for (short ii = 0; ii < D16; ii += NL) { @@ -3497,10 +3128,10 @@ kernel void kernel_flash_attn_ext_vec( // mqk = mqk*scale + mask*slope if (tx == 0) { - mqk *= scale; + mqk *= args.scale; - if (logit_softcap != 0.0f) { - mqk = logit_softcap*precise::tanh(mqk); + if (args.logit_softcap != 0.0f) { + mqk = args.logit_softcap*precise::tanh(mqk); } mqk += sm[4*cc + ty]*slope; @@ -3539,7 +3170,7 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { for (short cc = 0; cc < C/4; ++cc) { - device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); const s4x4_t ms(ss[4*cc + ty]); @@ -3644,7 +3275,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } @@ -3686,42 +3317,27 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ template kernel void kernel_cpy( - device const void * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const 
int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); - device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { + device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; } } @@ -3741,42 +3357,27 @@ template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy 0 ? sumqx/sumq2 : d; - } } kernel void kernel_concat( + constant ggml_metal_kargs_concat & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & dim, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + int o[4] = {0, 0, 0, 0}; + o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? 
args.ne02 : args.ne03)); device const float * x; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00); } else { - x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10); } - device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); *y = *x; } } +template void kernel_mul_mv_q2_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4305,92 +3793,64 @@ void kernel_mul_mv_q2_K_f32_impl( (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - qs += nb01/2; - sc += nb01; - dh += nb01/2; + qs += args.nb01/2; + sc += args.nb01; + dh += args.nb01/2; } y4 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } [[host_name("kernel_mul_mv_q2_K_f32")]] kernel void kernel_mul_mv_q2_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - 
constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float yl[32]; @@ -4420,9 +3880,10 @@ void kernel_mul_mv_q3_K_f32_impl( const ushort4 hm = mm[2*ip + il/2]; - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; + const short shift = 2*il; + + const float v1 = il == 0 ? 
4.f : 64.f; + const float v2 = 4.f * v1; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; @@ -4491,10 +3952,10 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] += d1 * (scales[1] - 32); sumf2[row] += d2 * (scales[3] - 32); - q += nb01/2; - h += nb01/2; - a += nb01/2; - dh += nb01/2; + q += args.nb01/2; + h += args.nb01/2; + a += args.nb01/2; + dh += args.nb01/2; } y1 += 4 * QK_K; @@ -4504,66 +3965,39 @@ void kernel_mul_mv_q3_K_f32_impl( const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); sumf1[row] = simd_sum(sumf); } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + if (tiisg == 0) { for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; + dst_f32[first_row + row] = sumf1[row]; } } } [[host_name("kernel_mul_mv_q3_K_f32")]] kernel void kernel_mul_mv_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -4574,21 +4008,21 @@ void kernel_mul_mv_q4_K_f32_impl( const int iq = it/4; // 0 or 1 const int ir = it%4; // 0...3 - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = r0 * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 
+ (i13 )*args.nb13; - device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[16]; float yh[16]; @@ -4641,92 +4075,64 @@ void kernel_mul_mv_q4_K_f32_impl( (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += nb01/2; - sc += nb01/2; - dh += nb01/2; + q1 += args.nb01/2; + sc += args.nb01/2; + dh += args.nb01/2; } y4 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } [[host_name("kernel_mul_mv_q4_K_f32")]] kernel void kernel_mul_mv_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q5_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 
)*args.nb12 + (i13 )*args.nb13; - device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float sumf[2]={0.f}; @@ -4800,98 +4206,70 @@ void kernel_mul_mv_q5_K_f32_impl( sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += nb01; - qh += nb01; - dh += nb01/2; - a += nb01/2; + q1 += args.nb01; + qh += args.nb01; + dh += args.nb01/2; + a += args.nb01/2; } y1 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + dst_f32[first_row + row] = tot; } } } [[host_name("kernel_mul_mv_q5_K_f32")]] kernel void kernel_mul_mv_q5_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q6_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; const uint8_t kmask3 = 0x30; const uint8_t kmask4 = 0xC0; - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const int row = 2 * r0 + sgitg; + const int row = 2*r0 + sgitg; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + 
(i13 )*nb13; + const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float sumf = 0; @@ -4908,7 +4286,6 @@ void kernel_mul_mv_q6_K_f32_impl( const int q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { - device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; @@ -4930,98 +4307,70 @@ void kernel_mul_mv_q6_K_f32_impl( } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; + dst_f32[row] = tot; } } [[host_name("kernel_mul_mv_q6_K_f32")]] kernel void kernel_mul_mv_q6_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit +template void kernel_mul_mv_iq2_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + 
(i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -5051,114 +4400,85 @@ void kernel_mul_mv_iq2_xxs_f32_impl( float sum = 0; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } sumf[row] += d * sum; - dh += nb01/2; - q2 += nb01/2; + dh += args.nb01/2; + q2 += args.nb01/2; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } [[host_name("kernel_mul_mv_iq2_xxs_f32")]] kernel void kernel_mul_mv_iq2_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, 
dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512); { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -5190,122 +4510,94 @@ void kernel_mul_mv_iq2_xs_f32_impl( float sum1 = 0, sum2 = 0; for (int l = 0; l < 2; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } for (int l = 2; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); } } sumf[row] += d1 * sum1 + d2 * sum2; - dh += nb01/2; - q2 += nb01/2; - sc += nb01; + dh += args.nb01/2; + q2 += args.nb01/2; + sc += args.nb01; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } - } -} - -[[host_name("kernel_mul_mv_iq2_xs_f32")]] -kernel void kernel_mul_mv_iq2_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + } +} + +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); + device const float * y = (device const float *) 
(src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -5314,7 +4606,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5328,16 +4619,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { - const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
-1.f : 1.f); @@ -5345,103 +4635,75 @@ void kernel_mul_mv_iq3_xxs_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb01/2; - q3 += nb01; - gas += nb01/2; + dh += args.nb01/2; + q3 += args.nb01; + gas += args.nb01/2; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + dst_f32[first_row + row] = all_sum * 0.5f; } } } [[host_name("kernel_mul_mv_iq3_xxs_f32")]] kernel void kernel_mul_mv_iq3_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float 
sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem; { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -5472,8 +4734,8 @@ void kernel_mul_mv_iq3_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; - const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); for (int j = 0; j < 4; ++j) { @@ -5483,105 +4745,77 @@ void kernel_mul_mv_iq3_s_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb01/2; - qs += nb01; - qh += nb01; - sc += nb01; - signs += nb01; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } [[host_name("kernel_mul_mv_iq3_s_f32")]] kernel void kernel_mul_mv_iq3_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint 
sgitg) { - - const int nb = ne00/QK_K; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem; //{ // int nval = 32; // int pos = (32*sgitg + tiisg)*nval; - // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i]; // threadgroup_barrier(mem_flags::mem_threadgroup); //} @@ -5613,8 +4847,8 @@ void kernel_mul_mv_iq2_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 2; ++l) { - //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); for (int j = 0; j < 8; ++j) { @@ -5624,94 +4858,66 @@ void kernel_mul_mv_iq2_s_f32_impl( } sumf[row] += d1 * sum[0] + d2 * sum[1]; - dh += nb01/2; - qs += nb01; - qh += nb01; - sc += nb01; - signs += nb01; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } [[host_name("kernel_mul_mv_iq2_s_f32")]] kernel void kernel_mul_mv_iq2_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - 
constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq1_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5754,61 +4960,50 @@ void kernel_mul_mv_iq1_s_f32_impl( } sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - dh += nb01/2; - qs += nb01; - qh += nb01/2; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01/2; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } +template void kernel_mul_mv_iq1_m_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5860,66 +5055,55 @@ void kernel_mul_mv_iq1_m_f32_impl( sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); - sc += nb01/2; - qs += nb01; - qh += nb01; + sc += args.nb01/2; + qs += args.nb01; + qh += args.nb01; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } +template void kernel_mul_mv_iq4_nl_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK4_NL; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + const int nb = 
args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -5937,7 +5121,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { device const block_iq4_nl & xb = x[row*nb + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); @@ -5947,16 +5131,16 @@ void kernel_mul_mv_iq4_nl_f32_impl( aux32[0] = q4[0] | (q4[1] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; aux32[0] = q4[2] | (q4[3] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -5969,60 +5153,49 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } +template void kernel_mul_mv_iq4_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - threadgroup float * 
shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK_K; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); const int ix = tiisg/16; // 0 or 1 const int it = tiisg%16; // 0...15 const int ib = it/2; const int il = it%2; - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -6036,28 +5209,26 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; for (int ibl = ix; ibl < nb; ibl += 2) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; for (int row = 0; row < 2; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; - aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[0] = (q4[0] ) & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; - aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[0] = (q4[1] ) & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -6071,134 +5242,68 @@ void kernel_mul_mv_iq4_xs_f32_impl( yb += 2 * QK_K; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant 
int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] kernel void kernel_mul_mv_iq1_m_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, 
src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -6302,38 +5407,26 @@ kernel void kernel_get_rows_i32( // each block_q contains 16*nl weights template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup T * sa = (threadgroup T *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; +kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); + + const int r0 = tgpig.y; + const int r1 = tgpig.x; + const int im = tgpig.z; // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? 
(args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; @@ -6349,20 +5442,20 @@ kernel void kernel_mul_mm(device const uchar * src0, short il = (tiitg % THREAD_PER_ROW); - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; - uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03; - ushort offset1 = il/nl; + uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + short offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1; + device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 - + nb13 * i13 - + nb12 * i12 - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*(r1 * BLOCK_SIZE_N + thread_col) + + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory T4x4 temp_a; dequantize_func(x, il, temp_a); @@ -6409,16 +5502,18 @@ kernel void kernel_mul_mm(device const uchar * src0, } } - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0); } } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M; for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); @@ -6428,7 +5523,7 @@ kernel void kernel_mul_mm(device const uchar * src0, if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0; + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0; device float4 * D4 = (device float4 *) D; threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); @@ -6449,36 +5544,37 @@ kernel void kernel_mul_mm(device const uchar * src0, } // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids +// TODO: this kernel needs to be reimplemented from scratch for better performance template void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, + int32_t ne00, + int32_t ne02, + uint64_t nb01, + uint64_t nb02, 
+ int32_t ne11, + int32_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int32_t ne0, + int32_t ne1, + int64_t ne0ne1, + device const char * src0, + device const char * src1, threadgroup ushort2 * rowids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - int64_t ne0ne1, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - - if (r1 * BLOCK_SIZE_N >= ne1) return; + device char * dst, + threadgroup char * shmem, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); + + const int r0 = tgpig.y; + const int r1 = tgpig.x; + + if (r1*BLOCK_SIZE_N >= ne1) return; // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; @@ -6490,9 +5586,9 @@ void kernel_mul_mm_id_impl( simdgroup_half8x8 ma[4]; simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; + simdgroup_float8x8 mc[8]; for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); + mc[i] = make_filled_simdgroup_matrix(0.f); } short il = (tiitg % THREAD_PER_ROW); @@ -6530,11 +5626,14 @@ void kernel_mul_mm_id_impl( threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(BLOCK_SIZE_K/8) for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) for (int i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) for (int i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } @@ -6542,29 +5641,42 @@ void kernel_mul_mm_id_impl( lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } } } { threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0); if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int joff = jid[0] * ne0 + jid[1] * ne0ne1; - for (int i = 0; i < n_rows; i++) { - *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); + int64_t joff = jid[0]*ne0 + 
jid[1]*ne0ne1; + + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); } } } @@ -6573,48 +5685,34 @@ void kernel_mul_mm_id_impl( template kernel void kernel_mul_mm_id( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const int32_t i02 = tgpig.z; + tgpig.z = 0; - device const uchar * src0 = src0s + i02*nb02; + device const char * src0 = src0s + i02*args.nb02; // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); // TODO: parallelize this loop - int64_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < nei1; ii1++) { - for (ushort ii0 = 0; ii0 < nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + int32_t _ne1 = 0; + for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { + for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; if (id == i02) { - //if (tiitg == 0) { + if (tiitg == 0) { rowids[_ne1] = ushort2(ii0, ii1); - //} + } _ne1++; } } @@ -6623,23 +5721,23 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); kernel_mul_mm_id_impl( + args.ne00, + args.ne02, + args.nb01, + args.nb02, + args.ne11, + args.ne12, + args.nb10, + args.nb11, + args.nb12, + args.ne0, + _ne1, + (int64_t)args.ne0*args.ne1, src0, src1, rowids, dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, + shmem, tgpig, tiitg, sgitg); @@ -6748,194 +5846,110 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel // typedef void (kernel_mul_mv_impl_t)( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg); + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig, + ushort tiisg); typedef void (kernel_mul_mv2_impl_t)( - 
device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg); + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg); template void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg); + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { + impl_fn(args, src0, src1, dst, tgpig, tiisg); } template void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); -} - -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { + impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int iid1 = tgpig.z/nei0; - const int idx = tgpig.z%nei0; + constant ggml_metal_kargs_mul_mv_id & args, + device const char * src0s, + device const 
char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const int iid1 = tgpig.z/args.nei0; + const int idx = tgpig.z%args.nei0; tgpig.z = 0; - const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; + const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx]; - const int64_t i11 = idx % ne11; + const int64_t i11 = idx % args.ne11; const int64_t i12 = iid1; const int64_t i1 = idx; const int64_t i2 = i12; - device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + device const char * src0_cur = src0s + i02*args.nb02; + device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12; + + device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float); + + ggml_metal_kargs_mul_mv args0 = { + /*.ne00 =*/ args.ne00, + /*.ne01 =*/ args.ne01, + /*.ne02 =*/ 1, // args.ne02, + /*.nb00 =*/ args.nb00, + /*.nb01 =*/ args.nb01, + /*.nb02 =*/ args.nb02, + /*.nb03 =*/ args.nb02, // args.ne02 == 1 + /*.ne10 =*/ args.ne10, + /*.ne11 =*/ 1, // args.ne11, + /*.ne12 =*/ 1, // args.ne12, + /*.nb10 =*/ args.nb10, + /*.nb11 =*/ args.nb11, + /*.nb12 =*/ args.nb12, + /*.nb13 =*/ args.nb12, // ne12 == 1 + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ 1, // args.ne1, + /*.r2 =*/ 1, + /*.r3 =*/ 1, + }; impl_fn( + args0, /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - /* ne00 */ ne00, - /* ne01 */ ne01, - /* ne02 */ 1, // ne02, - /* nb00 */ nb00, - /* nb01 */ nb01, - /* nb02 */ nb02, - /* nb03 */ nb02, // ne02 == 1 - /* ne10 */ ne10, - /* ne11 */ 1, // ne11, - /* ne12 */ 1, // ne12, - /* ne13 */ 1, // ne13, - /* nb10 */ nb10, - /* nb11 */ nb11, - /* nb12 */ nb12, - /* ne13 */ nb12, // ne12 == 1 - /* ne0 */ ne0, - /* ne1 */ 1, // ne1, - /* nb1 */ nb1, - /* r2 */ 1, - /* r3 */ 1, - shared_values, + shmem, tgpig, tiitg, tiisg, From ffd430e6ced424f544dec3c6d8b8e0c1d0874946 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Sun, 17 Nov 2024 12:25:45 +0100 Subject: [PATCH 13/17] llama : only use default buffer types for the KV cache (llama/10358) --- src/ggml-backend.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/ggml-backend.cpp b/src/ggml-backend.cpp index 9a6010d36..9dcde8d11 100644 --- a/src/ggml-backend.cpp +++ b/src/ggml-backend.cpp @@ -689,7 +689,7 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen } static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) { - ggml_backend_buffer_t buffer = tensor->buffer; + ggml_backend_buffer_t buffer = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; if (buffer == NULL) { return -1; } @@ -722,8 +722,6 @@ static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML // returns the backend that should be used for the node based on the current locations static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) { - // TODO: use supports_op to check if the backend supports the op - // assign pre-allocated nodes to their backend int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor); if (cur_backend_id != -1) { @@ -742,7 +740,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) { // since the tensor is pre-allocated, it cannot be moved to another backend - GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation"); + GGML_ABORT("pre-allocated tensor (%s) in a backend that cannot run the operation", tensor->name); } // graph input @@ -886,6 +884,9 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; int * node_backend_id = &tensor_backend_id(node); + if (ggml_is_view_op(node->op)) { + continue; + } // do not overwrite user assignments if (*node_backend_id == -1) { *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node); From 8560efb332440ebd66da7451c749e1f99ac85865 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 17 Nov 2024 12:59:38 +0100 Subject: [PATCH 14/17] CMake: fix typo in comment [no ci] (llama/10360) --- src/ggml-cuda/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ggml-cuda/CMakeLists.txt b/src/ggml-cuda/CMakeLists.txt index 3dde0f366..e592f7989 100644 --- a/src/ggml-cuda/CMakeLists.txt +++ b/src/ggml-cuda/CMakeLists.txt @@ -11,7 +11,7 @@ if (CUDAToolkit_FOUND) # 60 == P100, FP16 CUDA intrinsics # 61 == Pascal, __dp4a instruction (per-byte integer dot product) # 70 == V100, FP16 tensor cores - # 75 == Turing, int6 tensor cores + # 75 == Turing, int8 tensor cores if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6") set(CMAKE_CUDA_ARCHITECTURES "native") elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) From d7e136781627b4c28eed048a08eb9f06b53dee3b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Nov 2024 14:02:10 +0200 Subject: [PATCH 15/17] sync : llama.cpp ggml-ci --- scripts/sync-llama.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-llama.last b/scripts/sync-llama.last index 6984a37d5..64fe8ba30 100644 --- a/scripts/sync-llama.last +++ b/scripts/sync-llama.last @@ -1 +1 @@ -883d206fbd2c5b2b9b589a9328503b9005e146c9 +ce2e59ba107cf71ed566040ff20a15d1c58e09c2 From fa00172685beb145057917fa9961e9e151c29a70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 17 Nov 2024 23:20:42 +0100 Subject: [PATCH 16/17] CUDA: fix MMV kernel being used for FP16 src1 (llama/10357) --- src/ggml-cuda/ggml-cuda.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu index ef56e944d..dd94ab03d 100644 --- a/src/ggml-cuda/ggml-cuda.cu +++ b/src/ggml-cuda/ggml-cuda.cu @@ -1760,11 +1760,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = 
%s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && src0->type == GGML_TYPE_F16 && src1->ne[1] == 1 && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + // the custom F16 vector kernel can be used over batched cuBLAS GEMM + // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { - // KQ + KQV multi-batch without FlashAttention + // general KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_mul_mat_vec) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr); From 178ebfcc5f125085d51e0953b2d8230c21358650 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 18 Nov 2024 10:14:04 +0200 Subject: [PATCH 17/17] sync : llama.cpp ggml-ci --- scripts/sync-llama.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-llama.last b/scripts/sync-llama.last index 64fe8ba30..c84ddd384 100644 --- a/scripts/sync-llama.last +++ b/scripts/sync-llama.last @@ -1 +1 @@ -ce2e59ba107cf71ed566040ff20a15d1c58e09c2 +75207b3a887f91f813de1eb6e9fd135d3cb2b8c6