Skip to content

Commit

Permalink
metal : fix support check
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Nov 4, 2024
1 parent 41b47e5 commit 82a7012
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion ggml/src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_FLASH_ATTN_EXT:
+ if (op->src[1]->type != op->src[2]->type) {
+ return false;
+ }
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
Expand Down Expand Up @@ -2886,6 +2889,7 @@ static void ggml_metal_encode_node(
GGML_ASSERT(ne11 % 32 == 0);

GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == src2->type);

GGML_ASSERT(ggml_are_same_shape (src1, src2));

Expand Down Expand Up @@ -3158,7 +3162,7 @@ static void ggml_metal_encode_node(

[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
- // half1x4 kernel
+ // half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!

Expand Down

0 comments on commit 82a7012

Please sign in to comment.