Skip to content

Commit

Permalink
mtl : adapt the MNIST example as starter
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed May 29, 2023
1 parent 98c267f commit b23fe8c
Show file tree
Hide file tree
Showing 4 changed files with 458 additions and 0 deletions.
22 changes: 22 additions & 0 deletions examples/mtl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@ set(TARGET mtl-export)
add_executable(${TARGET} mtl-export.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

if (APPLE)
#
# mtl

find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)

set(TEST_TARGET mtl)
add_executable(${TEST_TARGET} mtl.cpp mtl.h mtl.m)
target_link_libraries(${TEST_TARGET} PRIVATE
ggml
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
)
endif()

51 changes: 51 additions & 0 deletions examples/mtl/mtl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include "ggml.h"
#include "mtl.h"

#include <cstdio>
#include <cstring>
#include <cstdlib>

int main(int argc, char ** argv) {
ggml_time_init();

if (argc != 2) {
fprintf(stderr, "Usage: %s llama.ggml\n", argv[0]);
return -1;
}

const char * fname_cgraph = argv[1];

// load the compute graph
struct ggml_context * ctx_data = NULL;
struct ggml_context * ctx_eval = NULL;

struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
gf.n_threads = 1;

// allocate work context
static size_t buf_size = gf.work_size; // TODO
static void * buf = malloc(buf_size);

struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};

struct ggml_context * ctx_work = ggml_init(params);

// this allocates all Metal resources and memory buffers
auto * ctx_mtl = llama_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);

// the actual inference happens here
llama_mtl_eval(ctx_mtl, &gf);

llama_mtl_free(ctx_mtl);

ggml_free(ctx_work);
ggml_free(ctx_data);
ggml_free(ctx_eval);

return 0;
}

28 changes: 28 additions & 0 deletions examples/mtl/mtl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

struct ggml_context;
struct ggml_cgraph;

#ifdef __cplusplus
extern "C" {
#endif

struct ggml_mtl_context;

struct ggml_mtl_context * llama_mtl_init(
struct ggml_context * ctx_data,
struct ggml_context * ctx_eval,
struct ggml_context * ctx_work,
struct ggml_cgraph * gf);

void llama_mtl_free(struct ggml_mtl_context * ctx);

// return 0 on success
int llama_mtl_eval(
struct ggml_mtl_context * ctx,
struct ggml_cgraph * gf);

#ifdef __cplusplus
}
#endif

Loading

0 comments on commit b23fe8c

Please sign in to comment.