Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metal : fix build and some more comments #10229

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -3046,6 +3046,8 @@ static void ggml_metal_encode_node(

bool use_vec_kernel = false;

// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
// for now avoiding mainly to keep the number of templates/kernels a bit lower
if (ne01 >= 4 || (ne00%128 != 0)) {
switch (src1->type) {
case GGML_TYPE_F16:
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
const short D4 = D/4;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short NL = NW/4;
const short SH = 2*C; // shared memory per simdgroup
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
const short SH = 2*C; // shared memory per simdgroup

const short T = D + nsg*SH; // shared memory size per query in (half)

Expand Down Expand Up @@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(

// Q*K^T
{
// each simdgroup processes 1 query and 4 keys
// each simdgroup processes 1 query and 4 (NW/NL) keys
for (short cc = 0; cc < C/4; ++cc) {
qk_t mqk = 0.0;

Expand Down Expand Up @@ -3645,7 +3645,7 @@ kernel void kernel_flash_attn_ext_vec(
half, half4, half4x4, \
half4x4

typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;

template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
#if defined(GGML_METAL_USE_BF16)
Expand Down
Loading