Skip to content

Commit

Permalink
metal : add GGML_METAL_FORCE_FATTN_PREC_F16
Browse files (browse the repository at this point in the history)
ggml-ci
  • Loading branch information
ggerganov committed Nov 6, 2024
1 parent d0cff71 commit a797e5d
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 87 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,11 @@ endif # GGML_HIPBLAS

ifdef GGML_METAL
MK_CPPFLAGS += -DGGML_USE_METAL

ifdef GGML_METAL_FORCE_FATTN_PREC_F16
MK_CPPFLAGS += -DGGML_METAL_FORCE_FATTN_PREC_F16
endif # GGML_METAL_FORCE_FATTN_PREC_F16

MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
OBJ_GGML += ggml/src/ggml-metal.o
ifdef GGML_METAL_NDEBUG
Expand Down
1 change: 1 addition & 0 deletions ggml/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
option(GGML_KOMPUTE "ggml: use Kompute" OFF)
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
option(GGML_METAL_FORCE_FATTN_PREC_F16 "ggml: force F16 accumulators for FA kernels" OFF)
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ if (GGML_METAL)
add_compile_definitions(GGML_METAL_NDEBUG)
endif()

if (GGML_METAL_FORCE_FATTN_PREC_F16)
add_compile_definitions(GGML_METAL_FORCE_FATTN_PREC_F16)
endif()

# copy ggml-common.h and ggml-metal.metal to bin directory
configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
Expand Down
25 changes: 14 additions & 11 deletions ggml/src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

// TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels
#define GGML_METAL_FORCE_FATTN_PREC_F32

// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64

Expand Down Expand Up @@ -483,9 +480,8 @@ @implementation GGMLMetalClass
// dictionary of preprocessor macros
NSMutableDictionary * prep = [NSMutableDictionary dictionary];

// add GGML_METAL_FORCE_FATTN_PREC_F32
#if defined(GGML_METAL_FORCE_FATTN_PREC_F32)
[prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"];
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
[prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F16"];
#endif

MTLCompileOptions* options = [MTLCompileOptions new];
Expand Down Expand Up @@ -538,9 +534,14 @@ @implementation GGMLMetalClass
}
}

GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = yes\n", __func__);
#else
GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = no\n", __func__);
#endif
GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");

ctx->capture_next_compute = false;
ctx->capture_started = false;
Expand Down Expand Up @@ -3153,10 +3154,12 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

#ifdef GGML_METAL_FORCE_FATTN_PREC_F32
#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
const enum ggml_prec prec = GGML_PREC_DEFAULT;
#else
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
// TODO: support both precisions
const enum ggml_prec prec = GGML_PREC_F32;
//const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
#endif

const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;
Expand Down
Loading

0 comments on commit a797e5d

Please sign in to comment.