From 9c1ddc77a7c03d017aea00a40b82487ab1a365b5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 Nov 2023 19:39:24 +0200 Subject: [PATCH] cuda : fix im2col kernel --- ggml-cuda.cu | 27 ++++++++++++++++----------- ggml.c | 4 ++-- whisper.cpp | 24 ++++++++++++------------ 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 221214424b2..681fe4948ec 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4737,13 +4737,18 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, } static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) { - int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0; - int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1; - __syncthreads(); + const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0; + const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1; + + const int offset_dst = + (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW + + (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z); + if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) { - int offset_dst = (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW; - int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1; - dst[offset_dst + (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z)] = __float2half(x[offset_src + iih * IW + iiw]); + const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1; + dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]); + } else { + dst[offset_dst] = __float2half(0.0f); } } @@ -5735,7 +5740,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst, int KH, int KW, int N, int ofs0, int ofs1, int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { dim3 block_nums(IC, OH, OW); - dim3 block_dims(N, KH, KW); + dim3 block_dims(N, KH, KW); 
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } @@ -6714,16 +6719,16 @@ inline void ggml_cuda_op_im2col( const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; - const int64_t N = src1->ne[is_2D ? 3 : 2]; + const int64_t N = src1->ne[is_2D ? 3 : 2]; const int64_t IC = src1->ne[is_2D ? 2 : 1]; const int64_t IH = is_2D ? src1->ne[1] : 1; - const int64_t IW = src1->ne[0]; + const int64_t IW = src1->ne[0]; const int64_t KH = is_2D ? src0->ne[1] : 1; - const int64_t KW = src0->ne[0]; + const int64_t KW = src0->ne[0]; const int64_t OH = is_2D ? dst->ne[2] : 1; - const int64_t OW = dst->ne[1]; + const int64_t OW = dst->ne[1]; im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, OH, IW, IH, OW, IC, KH, KW, N, diff --git a/ggml.c b/ggml.c index 8d30cfda3d3..a6fb2817523 100644 --- a/ggml.c +++ b/ggml.c @@ -5227,13 +5227,13 @@ struct ggml_tensor * ggml_im2col( } const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; - const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); const int64_t ne[4] = { is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], OW, is_2D ? OH : b->ne[2], - is_2D ? b->ne[3] : 1, + is_2D ? 
b->ne[3] : 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); diff --git a/whisper.cpp b/whisper.cpp index 1371a6c921f..80ca5c9bb56 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1604,22 +1604,22 @@ static struct ggml_cgraph * whisper_build_graph_conv( // convolution + gelu { cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); - //cur = ggml_add(ctx0, cur, model.e_conv_1_b); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_1_b, - cur), - cur); + cur = ggml_add(ctx0, cur, model.e_conv_1_b); + //cur = ggml_add(ctx0, + // ggml_repeat(ctx0, + // model.e_conv_1_b, + // cur), + // cur); cur = ggml_gelu(ctx0, cur); cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); - //cur = ggml_add(ctx0, cur, model.e_conv_2_b); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_2_b, - cur), - cur); + cur = ggml_add(ctx0, cur, model.e_conv_2_b); + //cur = ggml_add(ctx0, + // ggml_repeat(ctx0, + // model.e_conv_2_b, + // cur), + // cur); cur = ggml_gelu(ctx0, cur); }