cuda : add batched cuBLAS GEMM for faster attention #3749

Merged: 10 commits, Oct 24, 2023
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -331,6 +331,7 @@ if (LLAMA_CUBLAS)
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
else()
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
#set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work
endif()
endif()
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
11 changes: 9 additions & 2 deletions examples/batched/batched.cpp
@@ -11,7 +11,7 @@ int main(int argc, char ** argv) {
gpt_params params;

if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]);
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]);
return 1 ;
}

@@ -21,6 +21,9 @@ int main(int argc, char ** argv) {
// total length of the sequences including the prompt
int n_len = 32;

// number of layers to offload to the GPU
int n_gpu_layers = 0;

if (argc >= 2) {
params.model = argv[1];
}
@@ -37,6 +40,10 @@ int main(int argc, char ** argv) {
n_len = std::atoi(argv[4]);
}

if (argc >= 6) {
n_gpu_layers = std::atoi(argv[5]);
}

if (params.prompt.empty()) {
params.prompt = "Hello my name is";
}
@@ -49,7 +56,7 @@ int main(int argc, char ** argv) {

llama_model_params model_params = llama_model_default_params();

// model_params.n_gpu_layers = 99; // offload all layers to the GPU
model_params.n_gpu_layers = n_gpu_layers;

llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);

183 changes: 172 additions & 11 deletions ggml-cuda.cu
@@ -4326,13 +4326,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous

const half * x = (const half *) vx;

const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
const int channel_x = channel / channel_x_divisor;

const int nrows_y = ncols_x;
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
const int row_dst = row_x;
const int row_dst = row_x;

const int idst = channel*nrows_dst + row_dst;

@@ -4345,13 +4345,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
break;
}

const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
const float xi = __half2float(x[ix]);

const int row_y = col_x;

const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
const int iy = channel*nrows_y + row_y;

const float xi = __half2float(x[ix]);

tmp += xi * y[iy];
}

@@ -7013,7 +7013,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
}

static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_is_permuted(src0));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -7023,11 +7024,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];

const int64_t ne12 = src1->ne[2];

const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];

const int64_t ne12 = src1->ne[2];

CUDA_CHECK(ggml_cuda_set_device(g_main_device));
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

@@ -7046,6 +7047,154 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
}

static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);

const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];

const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);

const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];

const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);

const int64_t ne1 = ggml_nelements(src1);
const int64_t ne = ggml_nelements(dst);

CUDA_CHECK(ggml_cuda_set_device(g_main_device));
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

int id;
CUDA_CHECK(cudaGetDevice(&id));
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));

ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
half * src0_as_f16 = (half *) src0_ddq;

ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];

ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];

// convert src1 to fp16
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);

size_t src1_as = 0;
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);

size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);

// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

#if 0
// use cublasGemmEx
{
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
int i03 = i13 / r3;
int i02 = i12 / r2;

CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
&beta_f16, (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
}
#else
// use cublasGemmBatchedEx
{
const int ne23 = ne12*ne13;

// TODO: avoid this alloc
void ** src0_ptrs = (void **) malloc(ne23*sizeof(void *));
void ** src1_ptrs = (void **) malloc(ne23*sizeof(void *));
void ** dst_ptrs = (void **) malloc(ne23*sizeof(void *));

for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
int i03 = i13 / r3;
int i02 = i12 / r2;

src0_ptrs[i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3];
src1_ptrs[i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
dst_ptrs [i12 + i13*ne12] = (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
}
}

// allocate device memory for pointers
void ** src0_ptrs_as = nullptr;
void ** src1_ptrs_as = nullptr;
void ** dst_ptrs_as = nullptr;

CUDA_CHECK(cudaMalloc(&src0_ptrs_as, ne23*sizeof(void *)));
CUDA_CHECK(cudaMalloc(&src1_ptrs_as, ne23*sizeof(void *)));
CUDA_CHECK(cudaMalloc(& dst_ptrs_as, ne23*sizeof(void *)));

// copy pointers to device
CUDA_CHECK(cudaMemcpy(src0_ptrs_as, src0_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(src1_ptrs_as, src1_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy( dst_ptrs_as, dst_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
@ggerganov (Owner, Author) commented on Oct 24, 2023:

Does anyone know: if I changed these cudaMemcpy calls to cudaMemcpyAsync, would I need to add some synchronization before calling cublasGemmBatchedEx?

A collaborator replied:

You only need to run them on the same stream, but I don't think that this can be made async because the host memory may already be freed by the time the copy happens. Running memcpy asynchronously also requires using host pinned memory.

If the cublasGemmBatchedEx needs to stay to support GQA, I would consider writing a kernel to calculate these values and calling cublas from the kernel. Additionally, cudaMalloc is usually very slow, which is why we have the memory pool. These allocations should be changed to use the memory pool.
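
As an illustration of those constraints (a sketch, not code from this PR; host_ptrs and dev_ptrs are placeholder names), an async version would have to use pinned host memory, issue the copy on the same stream that the cuBLAS handle is bound to, and keep the host buffer alive until the copy has finished:

// Pinned (page-locked) host memory is required for cudaMemcpyAsync to be truly asynchronous.
void ** host_ptrs = nullptr;
CUDA_CHECK(cudaMallocHost((void **) &host_ptrs, 3*ne23*sizeof(void *)));

// ... fill host_ptrs[0 .. 3*ne23) with the src0/src1/dst pointers ...

void ** dev_ptrs = nullptr;
CUDA_CHECK(cudaMalloc((void **) &dev_ptrs, 3*ne23*sizeof(void *)));

// Enqueued on main_stream: the copy is ordered before the GEMM without extra synchronization,
// because the cuBLAS handle was bound to the same stream via cublasSetStream().
CUDA_CHECK(cudaMemcpyAsync(dev_ptrs, host_ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice, main_stream));

// ... cublasGemmBatchedEx(...) using dev_ptrs ...

// The pinned buffer must outlive the copy, so free it only after the stream has caught up.
CUDA_CHECK(cudaStreamSynchronize(main_stream));
CUDA_CHECK(cudaFreeHost(host_ptrs));
CUDA_CHECK(cudaFree(dev_ptrs));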

@ggerganov (Owner, Author) replied:

I reduced the mallocs from 3 to 1, but when I try to replace it with ggml_cuda_pool_malloc() (see commented code) the computation crashes with an illegal memory access somewhere later. I couldn't figure out the cause - probably some memory alignment issue in the pointer arrays.
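
For reference, the kernel-based alternative suggested above (compute the pointer arrays on the device instead of staging them on the host) could look roughly like the sketch below. It is not part of this diff, and k_compute_batched_ptrs is an illustrative name:

// Sketch: one thread per (i12, i13) pair fills the three pointer sub-arrays directly
// in device memory, so no host staging buffers or host-to-device copies are needed.
static __global__ void k_compute_batched_ptrs(
        half * src0_f16, half * src1_f16, half * dst_f16,
        void ** ptrs, // [3*ne23]: src0 pointers, then src1 pointers, then dst pointers
        int ne12, int ne13, int ne23,
        size_t nb02, size_t nb03,  // src0 byte strides
        size_t nb12, size_t nb13,  // src1 byte strides (fp32 sizes, halved below for the fp16 copy)
        size_t nb2,  size_t nb3,   // dst  byte strides (fp32 sizes, halved below for the fp16 buffer)
        int r2, int r3) {
    const int i13 = blockIdx.x*blockDim.x + threadIdx.x;
    const int i12 = blockIdx.y*blockDim.y + threadIdx.y;

    if (i13 >= ne13 || i12 >= ne12) {
        return;
    }

    const int i03 = i13 / r3;
    const int i02 = i12 / r2;

    ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_f16 + i02*nb02   + i03*nb03;
    ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
    ptrs[2*ne23 + i12 + i13*ne12] = (char *)  dst_f16 + i12*nb2/2  + i13*nb3/2;
}

Launched on main_stream before the cublasGemmBatchedEx call, the three sub-arrays (ptrs + 0*ne23, ptrs + 1*ne23, ptrs + 2*ne23) could then be passed as the A/B/C pointer arrays, leaving only a single device allocation for ptrs.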


CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half),
(void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float),
&beta_f16, (void **) dst_ptrs_as, CUDA_R_16F, ne01,
ne23,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

// free device memory for pointers
CUDA_CHECK(cudaFree(src0_ptrs_as));
CUDA_CHECK(cudaFree(src1_ptrs_as));
CUDA_CHECK(cudaFree( dst_ptrs_as));

free(src0_ptrs);
free(src1_ptrs);
free( dst_ptrs);
}
#endif

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);

ggml_cuda_pool_free(src1_as_f16, src1_as);
ggml_cuda_pool_free(dst_f16, dst_as);
}

static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
@@ -7058,10 +7207,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
}
}

// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
//printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
//printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
// KQ
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
} else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
} else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
// KQV
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
} else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
} else if (src0->type == GGML_TYPE_F32) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
4 changes: 4 additions & 0 deletions ggml.c
@@ -16602,6 +16602,10 @@ static void ggml_compute_forward_cross_entropy_loss_back(
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
GGML_ASSERT(params);

if (tensor->op == GGML_OP_NONE) {
return;
}

#ifdef GGML_USE_CUBLAS
bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
if (skip_cpu) {