Skip to content

Commit

Permalink
cuda : fix im2col kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Nov 10, 2023
1 parent 000b952 commit 9c1ddc7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 25 deletions.
27 changes: 16 additions & 11 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4737,13 +4737,18 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
}

// im2col for convolution: unfolds input patches into rows of `dst`, converting f32 -> f16.
// Launch layout (see im2col_f32_f16_cuda): grid = (IC, OH, OW), block = (N, KH, KW), so
//   blockIdx.y/blockIdx.z  index the output pixel (oh, ow),
//   threadIdx.y/threadIdx.z index the kernel tap (kh, kw),
//   threadIdx.x indexes the batch and blockIdx.x the input channel.
// CHW = IC*KH*KW is the row stride of the unfolded destination; ofs0/ofs1 are the
// source strides (in elements) for batch and channel respectively.
// s0/s1 = stride, p0/p1 = padding, d0/d1 = dilation.
static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
    // input coordinates sampled by this (output pixel, kernel tap) pair
    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;

    // compute the destination index unconditionally: out-of-bounds (padding) taps
    // must still write an explicit zero, otherwise dst is left uninitialized there
    const int offset_dst =
        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);

    if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
        // in-bounds: copy the source element, converted to half
        const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
    } else {
        // padding region: write 0 instead of leaving garbage in dst
        dst[offset_dst] = __float2half(0.0f);
    }
}

Expand Down Expand Up @@ -5735,7 +5740,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
int KH, int KW, int N, int ofs0, int ofs1,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
dim3 block_nums(IC, OH, OW);
dim3 block_dims(N, KH, KW);
dim3 block_dims(N, KH, KW);
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}

Expand Down Expand Up @@ -6714,16 +6719,16 @@ inline void ggml_cuda_op_im2col(

const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;

const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t IW = src1->ne[0];

const int64_t KH = is_2D ? src0->ne[1] : 1;
const int64_t KW = src0->ne[0];
const int64_t KW = src0->ne[0];

const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const int64_t OW = dst->ne[1];

im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
OH, IW, IH, OW, IC, KH, KW, N,
Expand Down
4 changes: 2 additions & 2 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -5227,13 +5227,13 @@ struct ggml_tensor * ggml_im2col(
}

const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);

const int64_t ne[4] = {
is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
OW,
is_2D ? OH : b->ne[2],
is_2D ? b->ne[3] : 1,
is_2D ? b->ne[3] : 1,
};

struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
Expand Down
24 changes: 12 additions & 12 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1604,22 +1604,22 @@ static struct ggml_cgraph * whisper_build_graph_conv(
// convolution + gelu
{
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
//cur = ggml_add(ctx0, cur, model.e_conv_1_b);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
model.e_conv_1_b,
cur),
cur);
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
//cur = ggml_add(ctx0,
// ggml_repeat(ctx0,
// model.e_conv_1_b,
// cur),
// cur);

cur = ggml_gelu(ctx0, cur);

cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
//cur = ggml_add(ctx0, cur, model.e_conv_2_b);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
model.e_conv_2_b,
cur),
cur);
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
//cur = ggml_add(ctx0,
// ggml_repeat(ctx0,
// model.e_conv_2_b,
// cur),
// cur);

cur = ggml_gelu(ctx0, cur);
}
Expand Down

0 comments on commit 9c1ddc7

Please sign in to comment.