Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metal support #127

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ if (APPLE AND NOT WHISPER_NO_ACCELERATE)
else()
message(WARNING "Accelerate framework not found")
endif()

find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)

set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS}
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK})
endif()

if (WHISPER_SUPPORT_OPENBLAS)
Expand Down Expand Up @@ -168,6 +179,7 @@ set(TARGET whisper)

add_library(${TARGET}
ggml.c
ggml-mtl.m
whisper.cpp
)

Expand Down
15 changes: 9 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ endif
ifndef WHISPER_NO_ACCELERATE
# Mac M1 - include Accelerate framework
ifeq ($(UNAME_S),Darwin)
CFLAGS += -DGGML_USE_ACCELERATE
LDFLAGS += -framework Accelerate
CFLAGS += -DGGML_USE_ACCELERATE -DGGML_PERF
LDFLAGS += -framework Foundation -framework Accelerate -framework Metal -framework MetalKit -framework MetalPerformanceShaders
endif
endif
ifneq ($(filter aarch64%,$(UNAME_M)),)
Expand All @@ -81,18 +81,21 @@ endif
# Build library + main
#

main: examples/main/main.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp whisper.o ggml.o -o main $(LDFLAGS)
main: examples/main/main.cpp ggml.o ggml-mtl.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp whisper.o ggml.o ggml-mtl.o -o main $(LDFLAGS)
./main -h

ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c ggml.c -o ggml.o

ggml-mtl.o: ggml-mtl.m ggml-mtl.h
$(CC) $(CFLAGS) -c ggml-mtl.m -o ggml-mtl.o

whisper.o: whisper.cpp whisper.h
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o

libwhisper.a: ggml.o whisper.o
$(AR) rcs libwhisper.a ggml.o whisper.o
libwhisper.a: ggml.o ggml-mtl.o whisper.o
$(AR) rcs libwhisper.a ggml.o ggml-mtl.o whisper.o

clean:
rm -f *.o main stream bench libwhisper.a
Expand Down
38 changes: 38 additions & 0 deletions ggml-mtl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include <stdint.h>
#include <stddef.h>

// Experimental Metal (MPS) backend for ggml matrix products.
// All backend state currently lives in file-scope globals in ggml-mtl.m;
// the context struct below is a placeholder for a future dynamic context.

// TODO: this will hold dynamic context data in the future
// currently unused
struct ggml_mtl_context {
void * dummy;
};

// Handle to a GPU buffer allocated with ggml_mtl_alloc.
//   id   - index into the backend's internal buffer table (-1 on failure)
//   data - CPU-visible pointer to the buffer contents
struct ggml_mtl_object {
int32_t id;
void * data;
};

// Initialize the global Metal device and command queue.
// NOTE(review): currently always returns NULL, even on success (see TODO in
// ggml-mtl.m) — callers must not treat a NULL result as failure yet.
struct ggml_mtl_context * ggml_mtl_init(void);

// Allocate a GPU buffer of `size` bytes; returns id == -1 when no slot is free.
struct ggml_mtl_object ggml_mtl_alloc(size_t size);

// multiply matrix by vector
// dst (f32, nrows elements) = src0 (f16, nrows x ncols) * src1 (f16, ncols
// elements); blocks until the GPU result is available in dst
void ggml_mtl_mul_mat_vec_f16(
struct ggml_mtl_context * ctx,
struct ggml_mtl_object src0, // matrix f16
const __fp16 * src1, // vector f16
float * dst, // vector f32
int nrows,
int ncols);

// multiply matrix by matrix
// dst (f32, nrows1 x nrows0, row-major) = src1 (f16, nrows1 x ncols)
// * transpose(src0) (f16, nrows0 x ncols); blocks until complete
void ggml_mtl_mul_mat_f16(
struct ggml_mtl_context * ctx,
struct ggml_mtl_object src0, // matrix f16
const __fp16 * src1, // matrix f16
float * dst, // matrix f32
int nrows0,
int nrows1,
int ncols);
162 changes: 162 additions & 0 deletions ggml-mtl.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#import "ggml-mtl.h"

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>

#define GGML_MTL_MAX_BUFFERS 256

// global static storage for Metal buffers
// TODO: move this into a dynamic context
static id<MTLBuffer> g_buffers[GGML_MTL_MAX_BUFFERS];

// global MTL context
// TODO: move this into a dynamic context
static id<MTLDevice> g_device;
static id<MTLCommandQueue> g_command_queue;

// Initialize the global Metal device and command queue.
// NOTE: currently always returns NULL, even on success — the dynamic context
// is not implemented yet (see TODO below), so callers cannot use the return
// value to detect failure.
struct ggml_mtl_context * ggml_mtl_init() {
    // TODO: implement properly
    // for now, init the global MTL context and MTL buffers
    g_device = MTLCreateSystemDefaultDevice();
    if (g_device == nil)
    {
        // MTLCreateSystemDefaultDevice returns nil when no Metal device is
        // available; fail early with an accurate message instead of falling
        // through and logging a misleading "command queue" error (messaging
        // nil below would simply yield a nil queue)
        NSLog(@"Failed to create a Metal device.");
        return nil;
    }

    g_command_queue = [g_device newCommandQueue];
    if (g_command_queue == nil)
    {
        NSLog(@"Failed to find the command queue.");
        return nil;
    }

    return nil;
}

// Allocate a GPU buffer of `size` bytes in the first free slot of the global
// buffer table and return a handle to it (id == -1 when the table is full).
// NOTE: there is currently no corresponding free — buffers live until exit.
struct ggml_mtl_object ggml_mtl_alloc(size_t size) {
    // TODO: temporarily making sure that the buffers are nil at the start
    static bool first = true;
    if (first) {
        for (int i = 0; i < GGML_MTL_MAX_BUFFERS; ++i) {
            assert(g_buffers[i] == nil);
        }
        first = false;
    }

    struct ggml_mtl_object obj = { -1, nil };

    for (int i = 0; i < GGML_MTL_MAX_BUFFERS; i++) {
        if (g_buffers[i] == nil) {
            g_buffers[i] = [g_device newBufferWithLength:size options:MTLResourceStorageModeManaged];

            // link the MTL buffer to the ggml object
            // NOTE(review): with MTLResourceStorageModeManaged, CPU writes
            // through this pointer must be followed by -didModifyRange:
            // before the GPU reads the buffer — confirm the callers do this
            // (on unified-memory systems the data happens to stay coherent)
            obj.id = i;
            obj.data = [g_buffers[i] contents];

            break;
        }
    }

    if (obj.id == -1) {
        // all slots taken — report instead of failing silently, since the
        // caller would otherwise index g_buffers with id == -1 later on
        NSLog(@"%s: out of buffer slots (max %d)", __func__, GGML_MTL_MAX_BUFFERS);
    }

    return obj;
}

// Dimensions of a matrix x vector product.
// NOTE(review): not referenced anywhere in this file — presumably intended
// for a future custom compute kernel; confirm before removing.
struct params_mul_mat_vec {
int N; // rows
int M; // cols
};

// multiply matrix with a vector using MPSMatrixVectorMultiplication
//
//   dst (f32, nrows) = src0 (f16, nrows x ncols) * src1 (f16, ncols)
//
// src0 must have been allocated with ggml_mtl_alloc; src1 is copied into a
// fresh device buffer. Blocks until the GPU has finished and the result has
// been copied into dst.
void ggml_mtl_mul_mat_vec_f16(
    struct ggml_mtl_context * ctx,
    struct ggml_mtl_object src0,
    const __fp16 * src1,
    float * dst,
    int nrows,
    int ncols) {
    (void) ctx; // unused — state lives in the file-scope globals for now

    // Create a command buffer to hold commands.
    id<MTLCommandBuffer> commandBuffer = [g_command_queue commandBuffer];
    assert(commandBuffer != nil);

    // make managed device buffer to store src1
    // (newBufferWithBytes copies src1, so the initial contents are in sync)
    id<MTLBuffer> src1_buffer = [g_device newBufferWithBytes:src1 length:ncols*sizeof(__fp16) options:MTLResourceStorageModeManaged];
    id<MTLBuffer> dst_buffer = [g_device newBufferWithLength:nrows*sizeof(float) options:MTLResourceStorageModeManaged];

    // MPSMatrixDescriptor
    MPSMatrixDescriptor *src0_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
    MPSVectorDescriptor *src1_desc = [MPSVectorDescriptor vectorDescriptorWithLength:ncols dataType:MPSDataTypeFloat16];
    MPSVectorDescriptor *dst_desc = [MPSVectorDescriptor vectorDescriptorWithLength:nrows dataType:MPSDataTypeFloat32];

    // MPSMatrix
    MPSMatrix *src0_mat = [[MPSMatrix alloc] initWithBuffer:g_buffers[src0.id] descriptor:src0_desc];
    MPSVector *src1_vec = [[MPSVector alloc] initWithBuffer:src1_buffer descriptor:src1_desc];
    MPSVector *dst_vec = [[MPSVector alloc] initWithBuffer:dst_buffer descriptor:dst_desc];

    // MPSMatrixVectorMultiplication: dst = 1.0*src0*src1 + 0.0*dst
    MPSMatrixVectorMultiplication *mul_mat_vec = [[MPSMatrixVectorMultiplication alloc] initWithDevice:g_device transpose:NO rows:nrows columns:ncols alpha:1.0 beta:0.0];

    // encode
    [mul_mat_vec encodeToCommandBuffer:commandBuffer
                           inputMatrix:src0_mat
                           inputVector:src1_vec
                          resultVector:dst_vec];

    // dst_buffer uses managed storage and is written by the GPU: on systems
    // with discrete video memory the CPU-side copy is stale until explicitly
    // synchronized back, so encode a blit synchronization before committing
    // (a no-op on unified-memory devices)
    id<MTLBlitCommandEncoder> blit = [commandBuffer blitCommandEncoder];
    [blit synchronizeResource:dst_buffer];
    [blit endEncoding];

    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];

    // copy GPU result to CPU
    memcpy(dst, [dst_buffer contents], nrows*sizeof(float));
}

// multiply matrix with a matrix using MPSMatrixMultiplication
//
//   dst (f32, nrows1 x nrows0, row-major) =
//       src1 (f16, nrows1 x ncols) * transpose(src0) (f16, nrows0 x ncols)
//
// src0 must have been allocated with ggml_mtl_alloc; src1 is copied into a
// fresh device buffer. Blocks until the GPU has finished and the result has
// been copied into dst.
void ggml_mtl_mul_mat_f16(
    struct ggml_mtl_context * ctx,
    struct ggml_mtl_object src0,
    const __fp16 * src1,
    float * dst,
    int nrows0,
    int nrows1,
    int ncols) {
    (void) ctx; // unused — state lives in the file-scope globals for now

    // Create a command buffer to hold commands.
    id<MTLCommandBuffer> commandBuffer = [g_command_queue commandBuffer];
    assert(commandBuffer != nil);

    // make managed device buffer to store src1
    // (newBufferWithBytes copies src1, so the initial contents are in sync)
    id<MTLBuffer> src1_buffer = [g_device newBufferWithBytes:src1 length:ncols*nrows1*sizeof(__fp16) options:MTLResourceStorageModeManaged];
    id<MTLBuffer> dst_buffer = [g_device newBufferWithLength:nrows0*nrows1*sizeof(float) options:MTLResourceStorageModeManaged];

    // MPSMatrixDescriptor
    MPSMatrixDescriptor *src0_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows0 columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
    MPSMatrixDescriptor *src1_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows1 columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
    MPSMatrixDescriptor *dst_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows1 columns:nrows0 rowBytes:nrows0*sizeof(float) dataType:MPSDataTypeFloat32];

    // MPSMatrix
    MPSMatrix *src0_mat = [[MPSMatrix alloc] initWithBuffer:g_buffers[src0.id] descriptor:src0_desc];
    MPSMatrix *src1_mat = [[MPSMatrix alloc] initWithBuffer:src1_buffer descriptor:src1_desc];
    MPSMatrix *dst_mat = [[MPSMatrix alloc] initWithBuffer:dst_buffer descriptor:dst_desc];

    // MPSMatrixMultiplication: dst = src1 * src0^T
    // (equivalent to computing transpose(src0 * src1^T) directly in dst)
    MPSMatrixMultiplication *mul_mat = [[MPSMatrixMultiplication alloc] initWithDevice:g_device transposeLeft:NO transposeRight:YES resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols alpha:1.0 beta:0.0];

    // encode
    [mul_mat encodeToCommandBuffer:commandBuffer
                        leftMatrix:src1_mat
                       rightMatrix:src0_mat
                      resultMatrix:dst_mat];

    // dst_buffer uses managed storage and is written by the GPU: on systems
    // with discrete video memory the CPU-side copy is stale until explicitly
    // synchronized back, so encode a blit synchronization before committing
    // (a no-op on unified-memory devices)
    id<MTLBlitCommandEncoder> blit = [commandBuffer blitCommandEncoder];
    [blit synchronizeResource:dst_buffer];
    [blit endEncoding];

    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];

    // copy GPU result to CPU
    memcpy(dst, [dst_buffer contents], nrows0*nrows1*sizeof(float));
}
Loading