Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve cuBLAS performance by using a memory pool #1094

Merged
merged 4 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ ifdef LLAMA_OPENBLAS
LDFLAGS += -lopenblas
endif
ifdef LLAMA_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
OBJS += ggml-cuda.o
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
OBJS += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-linker -arch=native
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
nvcc -arch=native -c -o $@ $<
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -c $< -o $@
endif
ifdef LLAMA_GPROF
CFLAGS += -pg
Expand Down
91 changes: 75 additions & 16 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <stdint.h>
#include <stdio.h>
#include <cuda_fp16.h>
#include <atomic>
#include "ggml-cuda.h"

typedef uint16_t ggml_fp16_t;
Expand Down Expand Up @@ -35,8 +37,6 @@ typedef struct {
} block_q4_3;
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");



static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
const block_q4_0 * x = (const block_q4_0 *) vx;

Expand Down Expand Up @@ -131,24 +131,83 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
}
}

extern "C" {
__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
}

__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_1;
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_1;
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
}

void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_2;
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
}

void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_3;
dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
}

// lock-free, thread safe buffer pool for cuda
#define MAX_CUDA_BUFFERS 16
struct cuda_buffer {
std::atomic_uintptr_t ptr { 0 };
size_t size { 0 };
};

static cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS];

void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
struct cuda_buffer * b = &cuda_buffer_pool[i];
if (b->size >= size) {
uintptr_t ptr = atomic_load(&b->ptr);
if (ptr) {
if (std::atomic_compare_exchange_strong(&b->ptr, &ptr, 0)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this has an ABA problem? The scenario I'm thinking of goes like this:

  • we get preempted just before the CAS and lose the race to ptr to another thread
  • this other thread eventually frees ptr with cudaFree()
  • some other thread calls cudaMalloc() with a smaller size and get the same pointer as ptr (i.e., it is equal to it as an integer), then frees it into the same pool slot we are trying to use
  • we wake up, do the CAS (which succeeds because the new pointer is equal to the old one as integer) and start using the pointer with the wrong size

This is highly unlikely to happen in practice, but I think is technically possible, unless CUDART never returns the same pointer twice from cudaMalloc().

*actual_size = b->size;
return (void *) ptr;
}
}
}
}

__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_2;
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
void * ptr;
CUDA_CHECK(cudaMalloc((void **) &ptr, size));
*actual_size = size;
return ptr;
}

void ggml_cuda_pool_free(void * ptr, size_t size) {
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
struct cuda_buffer * b = &cuda_buffer_pool[i];
uintptr_t p = std::atomic_load(&b->ptr);
if (p == 0) {
if (std::atomic_compare_exchange_strong(&b->ptr, &p, (uintptr_t) ptr)) {
b->size = size;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this introduces a race condition: another thread can observe a non-nullptr pointer with a wrong size. E.g., consider this execution history of two threads T0 and T1:

  • T0 calls ggml_cuda_pool_malloc(LARGE_SIZE). The pool is empty, so T0 calls cudaMalloc(LARGE_SIZE) and gets the resulting pointer pLarge.
  • T1 calls ggml_cuda_pool_malloc(SMALL_SIZE). The pool is again empty, to T1 calls cudaMalloc(SMALL_SIZE) and gets the resulting pointer pSmall.
  • T0 calls ggml_cuda_pool_free(pLarge). cuda_buffer_pool[0]->ptr is NULL, so T0 makes an update: cuda_buffer_pool[0] = {.ptr = pLarge, .size = LARGE_SIZE}.
  • T0 calls ggml_cuda_pool_malloc(LARGE_SIZE). cuda_buffer_pool[0]->size >= LARGE_SIZE && cuda_buffer_pool[0]->ptr != nullptr, so T0 makes an update (cuda_buffer_pool[0] = {.ptr = nullptr, .size = LARGE_SIZE}) and gets pLarge.
  • HERE BE DRAGONS T1 calls ggml_cuda_pool_free(pSmall). cuda_buffer_pool[0]->ptr is nullptr, so T1 tries to make an update. It makes a successful CAS on line 188 (so that cuda_buffer_pool[0] = {.ptr = pSmall, .size = LARGE_SIZE}), but then gets preempted by the OS scheduler.
  • T0 calls ggml_cuda_pool_malloc(LARGE_SIZE). cuda_buffer_pool[0]->size >= LARGE_SIZE && cuda_buffer_pool[0]->ptr != nullptr, so it makes an (irrelevant) update and gets pSmall.

So now T0 thinks it has LARGE_SIZE bytes, while in fact it only has SMALL_SIZE bytes.

return;
}
}
}
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
CUDA_CHECK(cudaFree(ptr));
}

cublasHandle_t cublasH = NULL;
cudaStream_t cudaStream = NULL;

void ggml_init_cublas(void) {
if (cublasH == NULL) {
// create cublas handle, bind a stream
CUBLAS_CHECK(cublasCreate(&cublasH));

CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));

CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));

__host__ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_3;
dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
}
}
31 changes: 31 additions & 0 deletions ggml-cuda.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>

#ifdef __cplusplus
extern "C" {
#endif

#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
exit(1); \
} \
} while (0)

#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)



extern cublasHandle_t cublasH;
extern cudaStream_t cudaStream;

void ggml_init_cublas(void);
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
void ggml_cuda_pool_free(void * ptr, size_t size);

void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
Expand Down
92 changes: 24 additions & 68 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -148,44 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) {
#elif defined(GGML_USE_OPENBLAS)
#include <cblas.h>
#elif defined(GGML_USE_CUBLAS)
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "ggml-cuda.h"

#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
exit(1); \
} \
} while (0)

#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)

static cublasHandle_t cublasH = NULL;
static cudaStream_t cudaStream = NULL;
static void init_cublas(void) {
if (cublasH == NULL) {
// create cublas handle, bind a stream
CUBLAS_CHECK(cublasCreate(&cublasH));

CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));

CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));

// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
}
}
#endif

#undef MIN
Expand Down Expand Up @@ -3720,7 +3683,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {

// initialize cuBLAS
#if defined(GGML_USE_CUBLAS)
init_cublas();
ggml_init_cublas();
#endif

is_first_call = false;
Expand Down Expand Up @@ -7566,18 +7529,16 @@ static void ggml_compute_forward_mul_mat_f32(
}

#if defined(GGML_USE_CUBLAS)
float *d_X = NULL;
float *d_Y = NULL;
float *d_D = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;

CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
size_t x_size, y_size, d_size;
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
#endif

for (int64_t i03 = 0; i03 < ne03; i03++) {
Expand Down Expand Up @@ -7614,9 +7575,9 @@ static void ggml_compute_forward_mul_mat_f32(
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
#endif
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);

Expand Down Expand Up @@ -7766,18 +7727,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
#if defined(GGML_USE_CUBLAS)
ggml_fp16_t * const wdata = params->wdata;

float *d_X = NULL;
float *d_Y = NULL;
float *d_D = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;

CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
size_t x_size, y_size, d_size;
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
#else
float * const wdata = params->wdata;
#endif
Expand Down Expand Up @@ -7844,9 +7803,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(

#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
#endif
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/

Expand Down Expand Up @@ -8014,20 +7973,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
}

#if defined(GGML_USE_CUBLAS)
float *d_X = NULL;
float *d_Y = NULL;
float *d_D = NULL;
float *d_Q = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;

CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
size_t x_size, y_size, d_size, q_size;
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);

void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
if (type == GGML_TYPE_Q4_0) {
Expand Down Expand Up @@ -8100,10 +8056,10 @@ static void ggml_compute_forward_mul_mat_q_f32(

#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
CUDA_CHECK(cudaFree(d_Q));
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
ggml_cuda_pool_free(d_Q, q_size);
#endif
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);

Expand Down