Skip to content

Commit

Permalink
mtl : matrix multiplication support
Browse files Browse the repository at this point in the history
Seems to be only marginally faster compared to pure AMX
  • Loading branch information
ggerganov committed Nov 9, 2022
1 parent 4e5674a commit b5d3521
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 5 deletions.
12 changes: 12 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ if (APPLE AND NOT WHISPER_NO_ACCELERATE)
else()
message(WARNING "Accelerate framework not found")
endif()

find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)

set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS}
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK})
endif()

if (WHISPER_SUPPORT_OPENBLAS)
Expand Down Expand Up @@ -168,6 +179,7 @@ set(TARGET whisper)

add_library(${TARGET}
ggml.c
ggml-mtl.m
whisper.cpp
)

Expand Down
38 changes: 38 additions & 0 deletions ggml-mtl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include <stdint.h>
#include <stddef.h>

// TODO: this will hold dynamic context data in the future
// currently unused
struct ggml_mtl_context {
void * dummy;
};

// handle to a device buffer allocated via ggml_mtl_alloc()
struct ggml_mtl_object {
int32_t id;  // index into the implementation's buffer table; -1 if allocation failed
void * data; // CPU-visible pointer to the buffer contents
};

// initialize the global Metal device and command queue
// NOTE(review): currently returns NULL even on success (no dynamic context yet)
struct ggml_mtl_context * ggml_mtl_init(void);

// allocate a device buffer of `size` bytes and return a handle to it
struct ggml_mtl_object ggml_mtl_alloc(size_t size);

// multiply matrix by vector: dst (nrows) = src0 (nrows x ncols) * src1 (ncols)
// blocks until the GPU result has been copied into dst
void ggml_mtl_mul_mat_vec_f16(
struct ggml_mtl_context * ctx,
struct ggml_mtl_object src0, // matrix f16
const __fp16 * src1, // vector f16
float * dst, // vector f32
int nrows,
int ncols);

// multiply matrix by matrix: dst = src1 * src0^T, row-major (nrows1 x nrows0)
// blocks until the GPU result has been copied into dst
void ggml_mtl_mul_mat_f16(
struct ggml_mtl_context * ctx,
struct ggml_mtl_object src0, // matrix f16 (nrows0 x ncols)
const __fp16 * src1, // matrix f16 (nrows1 x ncols)
float * dst, // matrix f32 (nrows1 x nrows0)
int nrows0,
int nrows1,
int ncols);
162 changes: 162 additions & 0 deletions ggml-mtl.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#import "ggml-mtl.h"

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>

#define GGML_MTL_MAX_BUFFERS 256

// global static storage for Metal buffers
// TODO: move this into a dynamic context
static id<MTLBuffer> g_buffers[GGML_MTL_MAX_BUFFERS];

// global MTL context
// TODO: move this into a dynamic context
static id<MTLDevice> g_device;
static id<MTLCommandQueue> g_command_queue;

// Initialize the global Metal device and command queue.
// Returns NULL in all cases for now: on failure, and (per the TODO) also on
// success, since there is no dynamic context yet.
struct ggml_mtl_context * ggml_mtl_init() {
    // TODO: implement properly
    // for now, init the global MTL context and MTL buffers
    g_device = MTLCreateSystemDefaultDevice();
    if (g_device == nil) {
        // without a device every subsequent Metal call silently no-ops
        NSLog(@"Failed to create the default Metal device.");
        return nil;
    }

    g_command_queue = [g_device newCommandQueue];
    if (g_command_queue == nil)
    {
        NSLog(@"Failed to find the command queue.");
        return nil;
    }

    return nil;
}

// Search for an unallocated buffer slot, create a Metal buffer of `size`
// bytes in it, and return a handle to it.
// On failure (no free slot, or buffer creation failed) the returned object
// has id == -1 and data == nil.
struct ggml_mtl_object ggml_mtl_alloc(size_t size) {
    // TODO: temporarily making sure that the buffers are nil at the start
    static bool first = true;
    if (first) {
        for (int i = 0; i < GGML_MTL_MAX_BUFFERS; ++i) {
            assert(g_buffers[i] == nil);
        }
        first = false;
    }

    struct ggml_mtl_object obj = { -1, nil };

    for (int i = 0; i < GGML_MTL_MAX_BUFFERS; i++) {
        if (g_buffers[i] == nil) {
            g_buffers[i] = [g_device newBufferWithLength:size options:MTLResourceStorageModeManaged];
            if (g_buffers[i] == nil) {
                // allocation failed - keep the slot free and report failure
                NSLog(@"Failed to allocate Metal buffer of %zu bytes.", size);
                break;
            }

            // link the MTL buffer to the ggml object
            obj.id = i;
            obj.data = [g_buffers[i] contents];

            break;
        }
    }

    return obj;
}

// dimensions for a matrix-vector multiplication
// NOTE(review): appears unused in the visible code — possibly reserved for a
// future custom Metal kernel; confirm before removing
struct params_mul_mat_vec {
int N; // rows
int M; // cols
};

// multiply matrix with a vector using MPSMatrixVectorMultiplication
//
//   dst (nrows, f32) = src0 (nrows x ncols, f16) * src1 (ncols, f16)
//
// Blocks until the GPU finishes and the result has been copied into dst.
void ggml_mtl_mul_mat_vec_f16(
        struct ggml_mtl_context * ctx,
        struct ggml_mtl_object src0,
        const __fp16 * src1,
        float * dst,
        int nrows,
        int ncols) {
    (void) ctx; // unused

    // Create a command buffer to hold commands.
    id<MTLCommandBuffer> commandBuffer = [g_command_queue commandBuffer];
    assert(commandBuffer != nil);

    // make managed device buffers for the input vector and the result
    id<MTLBuffer> src1_buffer = [g_device newBufferWithBytes:src1 length:ncols*sizeof(__fp16) options:MTLResourceStorageModeManaged];
    id<MTLBuffer> dst_buffer = [g_device newBufferWithLength:nrows*sizeof(float) options:MTLResourceStorageModeManaged];

    // MPSMatrixDescriptor
    MPSMatrixDescriptor *src0_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
    MPSVectorDescriptor *src1_desc = [MPSVectorDescriptor vectorDescriptorWithLength:ncols dataType:MPSDataTypeFloat16];
    MPSVectorDescriptor *dst_desc = [MPSVectorDescriptor vectorDescriptorWithLength:nrows dataType:MPSDataTypeFloat32];

    // MPSMatrix / MPSVector views over the device buffers
    MPSMatrix *src0_mat = [[MPSMatrix alloc] initWithBuffer:g_buffers[src0.id] descriptor:src0_desc];
    MPSVector *src1_vec = [[MPSVector alloc] initWithBuffer:src1_buffer descriptor:src1_desc];
    MPSVector *dst_vec = [[MPSVector alloc] initWithBuffer:dst_buffer descriptor:dst_desc];

    // MPSMatrixVectorMultiplication: y = 1.0*A*x + 0.0*y
    MPSMatrixVectorMultiplication *mul_mat_vec = [[MPSMatrixVectorMultiplication alloc] initWithDevice:g_device transpose:NO rows:nrows columns:ncols alpha:1.0 beta:0.0];

    // encode
    [mul_mat_vec encodeToCommandBuffer:commandBuffer
                           inputMatrix:src0_mat
                           inputVector:src1_vec
                          resultVector:dst_vec];

    // dst_buffer is MTLResourceStorageModeManaged: on discrete GPUs the CPU
    // copy is stale after a GPU write unless explicitly synchronized back
    id<MTLBlitCommandEncoder> blit = [commandBuffer blitCommandEncoder];
    [blit synchronizeResource:dst_buffer];
    [blit endEncoding];

    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];

    // copy GPU result to CPU
    memcpy(dst, [dst_buffer contents], nrows*sizeof(float));
}

// multiply matrix with a matrix using MPSMatrixMultiplication
//
//   dst (nrows1 x nrows0, f32) = src1 (nrows1 x ncols, f16) * src0^T (ncols x nrows0, f16)
//
// i.e. the result is laid out row-major with nrows1 rows of nrows0 floats.
// Blocks until the GPU finishes and the result has been copied into dst.
void ggml_mtl_mul_mat_f16(
        struct ggml_mtl_context * ctx,
        struct ggml_mtl_object src0,
        const __fp16 * src1,
        float * dst,
        int nrows0,
        int nrows1,
        int ncols) {
    (void) ctx; // unused

    // Create a command buffer to hold commands.
    id<MTLCommandBuffer> commandBuffer = [g_command_queue commandBuffer];
    assert(commandBuffer != nil);

    // make managed device buffers for the right-hand matrix and the result
    id<MTLBuffer> src1_buffer = [g_device newBufferWithBytes:src1 length:ncols*nrows1*sizeof(__fp16) options:MTLResourceStorageModeManaged];
    id<MTLBuffer> dst_buffer = [g_device newBufferWithLength:nrows0*nrows1*sizeof(float) options:MTLResourceStorageModeManaged];

    // MPSMatrixDescriptor
    MPSMatrixDescriptor *src0_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows0 columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
    MPSMatrixDescriptor *src1_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows1 columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
    MPSMatrixDescriptor *dst_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows1 columns:nrows0 rowBytes:nrows0*sizeof(float) dataType:MPSDataTypeFloat32];

    // MPSMatrix views over the device buffers
    MPSMatrix *src0_mat = [[MPSMatrix alloc] initWithBuffer:g_buffers[src0.id] descriptor:src0_desc];
    MPSMatrix *src1_mat = [[MPSMatrix alloc] initWithBuffer:src1_buffer descriptor:src1_desc];
    MPSMatrix *dst_mat = [[MPSMatrix alloc] initWithBuffer:dst_buffer descriptor:dst_desc];

    // MPSMatrixMultiplication: dst = src1 * src0^T
    // (left = src1, right transposed = src0, so the result rows follow src1)
    MPSMatrixMultiplication *mul_mat = [[MPSMatrixMultiplication alloc] initWithDevice:g_device transposeLeft:NO transposeRight:YES resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols alpha:1.0 beta:0.0];

    // encode
    [mul_mat encodeToCommandBuffer:commandBuffer
                        leftMatrix:src1_mat
                       rightMatrix:src0_mat
                      resultMatrix:dst_mat];

    // dst_buffer is MTLResourceStorageModeManaged: on discrete GPUs the CPU
    // copy is stale after a GPU write unless explicitly synchronized back
    id<MTLBlitCommandEncoder> blit = [commandBuffer blitCommandEncoder];
    [blit synchronizeResource:dst_buffer];
    [blit endEncoding];

    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];

    // copy GPU result to CPU
    memcpy(dst, [dst_buffer contents], nrows0*nrows1*sizeof(float));
}
Loading

0 comments on commit b5d3521

Please sign in to comment.