Skip to content

Commit

Permalink
metal : add im2col support + mul mat-vec f16 x f16
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Nov 10, 2023
1 parent 9c1ddc7 commit 997f7cb
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 40 deletions.
21 changes: 12 additions & 9 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4736,7 +4736,10 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}

static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
static __global__ void im2col_f32_f16(
const float * x, half * dst,
int ofs0, int ofs1, int IW, int IH, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;

Expand Down Expand Up @@ -5734,11 +5737,10 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}

static void im2col_f32_f16_cuda(const float* x, half* dst,
int OH, int IW, int IH,
int OW, int IC,
int KH, int KW, int N, int ofs0, int ofs1,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
static void im2col_f32_f16_cuda(const float * x, half * dst,
int OH, int IW, int IH, int OW, int IC,
int KH, int KW, int N, int ofs0, int ofs1,
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
dim3 block_nums(IC, OH, OW);
dim3 block_dims(N, KH, KW);
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
Expand Down Expand Up @@ -6730,11 +6732,12 @@ inline void ggml_cuda_op_im2col(
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];

const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32

im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
OH, IW, IH, OW, IC, KH, KW, N,
src1->nb[is_2D ? 3 : 2] / 4, // nb is byte offset, src is type float32
src1->nb[is_2D ? 2 : 1] / 4, // nb is byte offset, src is type float32
s0, s1, p0, p1, d0, d1, main_stream);
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);

(void) src0;
(void) src0_dd;
Expand Down
76 changes: 70 additions & 6 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
Expand Down Expand Up @@ -114,6 +115,7 @@
GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(im2col_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
Expand Down Expand Up @@ -287,6 +289,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
Expand Down Expand Up @@ -317,6 +320,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(im2col_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
Expand Down Expand Up @@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
Expand Down Expand Up @@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(alibi_f32);
GGML_METAL_DEL_KERNEL(im2col_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
Expand Down Expand Up @@ -1139,20 +1145,26 @@ void ggml_metal_graph_compute(
switch (src0t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
nrows = 4;
} break;
case GGML_TYPE_F16:
{
nth0 = 32;
nth1 = 1;
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
nrows = ne11;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
nrows = ne11;
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
nrows = 4;
}
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
nrows = 4;
}
} break;
Expand Down Expand Up @@ -1464,6 +1476,58 @@ void ggml_metal_graph_compute(

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_IM2COL:
{
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);

const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;

const int32_t N = src1->ne[is_2D ? 3 : 2];
const int32_t IC = src1->ne[is_2D ? 2 : 1];
const int32_t IH = is_2D ? src1->ne[1] : 1;
const int32_t IW = src1->ne[0];

const int32_t KH = is_2D ? src0->ne[1] : 1;
const int32_t KW = src0->ne[0];

const int32_t OH = is_2D ? dst->ne[2] : 1;
const int32_t OW = dst->ne[1];

const int32_t CHW = IC * KH * KW;

const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;

switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
default: GGML_ASSERT(false);
};

[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];

[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
Expand Down
108 changes: 107 additions & 1 deletion ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
uint tiisg[[thread_index_in_simdgroup]]) {

const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32;
Expand Down Expand Up @@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
}
}

#define N_F16_F16 4

kernel void kernel_mul_mv_f16_f16(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {

const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F16_F16;
const int64_t im = tgpig.z;

device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);

if (ne00 < 128) {
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}

device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);

float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (half) x[i] * (half) y[i];
}

float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}

device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
device const half4 * y4 = (device const half4 *) y;

float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
}

float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}

kernel void kernel_mul_mv_f16_f32_1row(
device const char * src0,
device const char * src1,
Expand Down Expand Up @@ -1229,6 +1302,39 @@ kernel void kernel_rope(
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;

kernel void kernel_im2col_f16(
device const float * x,
device half * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;

const int32_t offset_dst =
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);

if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
} else {
dst[offset_dst] = 0.0f;
}
}

kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,
Expand Down
53 changes: 29 additions & 24 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -5131,13 +5131,15 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
int s0,
int p0,
int d0) {
struct ggml_tensor * result = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
result =
ggml_reshape_3d(ctx,
ggml_mul_mat(ctx,
ggml_reshape_2d(ctx, result, result->ne[0], (result->ne[2] * result->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])), // [OC,IC, K] => [OC, IC * K]
result->ne[1], a->ne[2], result->ne[2]); // [N, OC, OL]
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]

struct ggml_tensor * result =
ggml_mul_mat(ctx,
ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]

result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]

return result;
}

Expand Down Expand Up @@ -5252,22 +5254,24 @@ struct ggml_tensor * ggml_im2col(
// b: [N, IC, IH, IW]
// result: [N, OC, OH, OW]
struct ggml_tensor * ggml_conv_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1) {
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1) {
struct ggml_tensor * result = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]

result =
ggml_reshape_4d(ctx,
ggml_mul_mat(ctx,
ggml_reshape_2d(ctx, result, result->ne[0], result->ne[3] * result->ne[2] * result->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])), // [OC,IC, KH, KW] => [OC, IC * KH * KW]
result->ne[1], result->ne[2], a->ne[3], result->ne[3]); // [N, OC, OH, OW]

return result;
}

Expand Down Expand Up @@ -11724,17 +11728,18 @@ static void ggml_compute_forward_im2col_f16(

GGML_TENSOR_BINARY_OP_LOCALS;

const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;

const int ith = params->ith;
const int nth = params->nth;
const int64_t N = is_2D ? ne13 : ne12;

const int64_t N = is_2D ? ne13 : ne12;
const int64_t IC = is_2D ? ne12 : ne11;
const int64_t IH = is_2D ? ne11 : 1;
const int64_t IW = ne10;
Expand Down

0 comments on commit 997f7cb

Please sign in to comment.