mtl : matrix multiplication support
Seems to be only slightly faster compared to AMX.
Probably need to optimize the MTL buffer creation
ggerganov committed Nov 8, 2022
1 parent a866b1c commit 2c6fc25
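An editor's note on the buffer-creation remark above: ggml_mtl_mul_mat_f16 below creates fresh MTLBuffer objects with newBufferWithBytes/newBufferWithLength on every call, which allocates and copies each time. One common mitigation, sketched here as an assumption and not part of this commit, is to reuse a lazily grown scratch buffer:

// Editor's sketch, not code from this commit: reuse one managed scratch
// buffer instead of calling newBufferWithBytes on every matmul.
static id<MTLBuffer> g_scratch = nil;

static id<MTLBuffer> mtl_scratch(size_t size) {
    if (g_scratch == nil || [g_scratch length] < size) {
        g_scratch = [g_device newBufferWithLength:size options:MTLResourceStorageModeManaged];
    }
    return g_scratch;
}
// The caller would then memcpy src1 into [g_scratch contents] and call
// [g_scratch didModifyRange:NSMakeRange(0, size)] so the GPU sees the update.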
Showing 4 changed files with 68 additions and 9 deletions.
11 changes: 10 additions & 1 deletion ggml-mtl.h
@@ -15,10 +15,19 @@ struct ggml_mtl_context * ggml_mtl_init(void);
 
 struct ggml_mtl_object ggml_mtl_alloc(size_t size);
 
-void ggml_mtl_mul_mat_f16(
+void ggml_mtl_mul_mat_vec_f16(
         struct ggml_mtl_context * ctx,
         struct ggml_mtl_object src0,
         const __fp16 * src1,
         float * dst,
         int nrows,
         int ncols);
+
+void ggml_mtl_mul_mat_f16(
+        struct ggml_mtl_context * ctx,
+        struct ggml_mtl_object src0,
+        const __fp16 * src1,
+        float * dst,
+        int nrows0,
+        int nrows1,
+        int ncols);
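A minimal usage sketch of the two entry points (an editor's illustration with made-up shapes, not code from the commit):

// Editor's sketch; shapes are hypothetical. A is a 64x128 f16 matrix whose
// contents are assumed to have been filled via the buffer from ggml_mtl_alloc.
struct ggml_mtl_context * ctx = ggml_mtl_init();
struct ggml_mtl_object A = ggml_mtl_alloc(64*128*sizeof(__fp16));

__fp16 x[128];    // one input vector of ncols elements
float  y[64];     // result: nrows floats
ggml_mtl_mul_mat_vec_f16(ctx, A, x, y, 64, 128);

__fp16 B[32*128]; // 32 input rows sharing A's column count
float  C[32*64];  // result: nrows1 x nrows0 floats
ggml_mtl_mul_mat_f16(ctx, A, B, C, 64, 32, 128);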
54 changes: 53 additions & 1 deletion ggml-mtl.m
@@ -64,7 +64,7 @@ struct ggml_mtl_object ggml_mtl_alloc(size_t size) {
 };
 
 // multiply matrix with a vector using MPSMatrixVectorMultiplication
-void ggml_mtl_mul_mat_f16(
+void ggml_mtl_mul_mat_vec_f16(
         struct ggml_mtl_context * ctx,
         struct ggml_mtl_object src0, // matrix f16
         const __fp16 * src1,         // vector f16
@@ -104,3 +104,55 @@ void ggml_mtl_mul_mat_f16(
     // copy GPU result to CPU
     memcpy(dst, [dst_buffer contents], nrows*sizeof(float));
 }
+
+// multiply matrix with a matrix using MPSMatrixMultiplication
+void ggml_mtl_mul_mat_f16(
+        struct ggml_mtl_context * ctx,
+        struct ggml_mtl_object src0, // matrix f16
+        const __fp16 * src1,         // matrix f16
+        float * dst,                 // matrix f32
+        int nrows0,
+        int nrows1,
+        int ncols) {
+    // Create a command buffer to hold commands.
+    id<MTLCommandBuffer> commandBuffer = [g_command_queue commandBuffer];
+    assert(commandBuffer != nil);
+
+    // make managed device buffer to store src1
+    id<MTLBuffer> src1_buffer = [g_device newBufferWithBytes:src1 length:ncols*nrows1*sizeof(__fp16) options:MTLResourceStorageModeManaged];
+    id<MTLBuffer> dst_buffer  = [g_device newBufferWithLength:nrows0*nrows1*sizeof(float) options:MTLResourceStorageModeManaged];
+
+    // MPSMatrixDescriptor
+    MPSMatrixDescriptor *src0_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows0 columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
+    MPSMatrixDescriptor *src1_desc = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows1 columns:ncols rowBytes:ncols*sizeof(__fp16) dataType:MPSDataTypeFloat16];
+    MPSMatrixDescriptor *dst_desc  = [MPSMatrixDescriptor matrixDescriptorWithRows:nrows1 columns:nrows0 rowBytes:nrows0*sizeof(float) dataType:MPSDataTypeFloat32];
+
+    // MPSMatrix
+    MPSMatrix *src0_mat = [[MPSMatrix alloc] initWithBuffer:g_buffers[src0.id] descriptor:src0_desc];
+    MPSMatrix *src1_mat = [[MPSMatrix alloc] initWithBuffer:src1_buffer descriptor:src1_desc];
+    MPSMatrix *dst_mat  = [[MPSMatrix alloc] initWithBuffer:dst_buffer descriptor:dst_desc];
+
+    //// MPSMatrixMultiplication z = x * yT
+    //MPSMatrixMultiplication *mul_mat = [[MPSMatrixMultiplication alloc] initWithDevice:g_device transposeLeft:NO transposeRight:YES resultRows:nrows resultColumns:nrows interiorColumns:ncols alpha:1.0 beta:0.0];
+
+    //// encode
+    //[mul_mat encodeToCommandBuffer:commandBuffer
+    //        leftMatrix:src0_mat
+    //        rightMatrix:src1_mat
+    //        resultMatrix:dst_mat];
+
+    // MPSMatrixMultiplication zT = xT * y
+    MPSMatrixMultiplication *mul_mat = [[MPSMatrixMultiplication alloc] initWithDevice:g_device transposeLeft:NO transposeRight:YES resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols alpha:1.0 beta:0.0];
+
+    // encode
+    [mul_mat encodeToCommandBuffer:commandBuffer
+            leftMatrix:src1_mat
+            rightMatrix:src0_mat
+            resultMatrix:dst_mat];
+
+    [commandBuffer commit];
+    [commandBuffer waitUntilCompleted];
+
+    // copy GPU result to CPU
+    memcpy(dst, [dst_buffer contents], nrows0*nrows1*sizeof(float));
+}
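A note on the operand swap (an editor's gloss, not part of the commit): the result ggml expects has the rows of src1 as the slow dimension, i.e. dst = src1 * src0^T with shape nrows1 x nrows0. Taking transposes,

    (src0 * src1^T)^T = src1 * src0^T

so encoding src1 as the left matrix with transposeRight:YES yields that row-major layout directly, exactly matching dst_desc, while the commented-out variant would have produced the transposed nrows0 x nrows1 result and required an extra transpose pass.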
8 changes: 3 additions & 5 deletions ggml.c
@@ -4463,7 +4463,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
     // compute by src0 columns
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst) && (src0->id < 0)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
         if (params->ith != 0) return;
@@ -4584,7 +4584,7 @@ void ggml_compute_forward_mul_mat_f16_f32(
         return;
     }
 
-    bool is_mtl = src0->id >= 0 && src1->ne[1] == 1;
+    bool is_mtl = src0->id >= 0;
 
     if (nb01 >= nb00) {
         // fp16 -> half the size, so divide by 2
@@ -4593,16 +4593,14 @@ void ggml_compute_forward_mul_mat_f16_f32(
 
         // parallelize by src0 rows using ggml_vec_dot_f32
 
-        const int nmtl = (1*ne01);
-
         if (is_mtl) {
             assert(ne02 == 1);
             assert(ne03 == 1);
 
             if (params->ith == 0) {
                 struct ggml_mtl_object src0_mtl = { src0->id, src0->data };
                 ggml_fp16_t * src1_fp16 = params->wdata;
-                ggml_mtl_mul_mat_f16(NULL, src0_mtl, src1_fp16, dst->data, nmtl, ne00);
+                ggml_mtl_mul_mat_f16(NULL, src0_mtl, src1_fp16, dst->data, ne01, ne11, ne00);
             }
             return;
         }
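How the ggml shapes map onto the new signature (an editor's illustration; the numbers are hypothetical): nrows0 = ne01 (rows of src0), nrows1 = ne11 (rows of src1), ncols = ne00 (the shared inner dimension).

// Hypothetical shapes, for illustration only:
//   src0: ne01 = 2048 rows, ne00 = 512 cols (f16 weights in a Metal buffer)
//   src1: ne11 = 64 rows of 512            (converted to f16 in params->wdata)
//   dst : 64 x 2048 floats written back by the GPU
ggml_mtl_mul_mat_f16(NULL, src0_mtl, src1_fp16, dst->data, ne01, ne11, ne00);

Note that only thread 0 encodes the MPS call; since the GPU computes the whole product, the remaining threads simply return.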
4 changes: 2 additions & 2 deletions whisper.cpp
@@ -788,10 +788,10 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
             layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
-            layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
+            layer.mlp_0_w = ggml_new_tensor_2d_mtl(ctx, wtype, n_audio_state, 4*n_audio_state);
             layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
 
-            layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
+            layer.mlp_1_w = ggml_new_tensor_2d_mtl(ctx, wtype, 4*n_audio_state, n_audio_state);
             layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
 
             layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
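ggml_new_tensor_2d_mtl is not shown in this diff (it comes from the parent commit). Judging from how src0->id and src0->data are used in ggml.c above, it presumably allocates the tensor's storage through ggml_mtl_alloc and records the Metal buffer index in the tensor's id field. A rough editor's reconstruction, not the actual code:

// Editor's guess at the constructor's shape; the real implementation in the
// parent commit may differ (e.g. in how it avoids a duplicate allocation).
struct ggml_tensor * ggml_new_tensor_2d_mtl(struct ggml_context * ctx, enum ggml_type type, int ne0, int ne1) {
    struct ggml_tensor * t = ggml_new_tensor_2d(ctx, type, ne0, ne1);
    struct ggml_mtl_object obj = ggml_mtl_alloc(ggml_nbytes(t));
    t->id   = obj.id;   // id >= 0 marks the tensor as Metal-backed
    t->data = obj.data; // reads and writes go through the shared MTLBuffer contents
    return t;
}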
