Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA: faster dequantize kernels for Q4_0 and Q4_1 #4938

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 73 additions & 4 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,61 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
#endif // GGML_CUDA_F16
}

template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {

const int i = blockIdx.x;

// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
const int ib = 8*i + ir;
if (ib >= nb32) {
return;
}

dst_t * y = yy + 256*i + 32*ir + 4*il;

const block_q4_0 * x = (const block_q4_0 *)vx + ib;
const float d = __half2float(x->d);
const float dm = -8*d;

const uint8_t * q = x->qs + 4*il;

for (int l = 0; l < 4; ++l) {
y[l+ 0] = d * (q[l] & 0xF) + dm;
y[l+16] = d * (q[l] >> 4) + dm;
}
}

template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {

const int i = blockIdx.x;

// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
const int ib = 8*i + ir;
if (ib >= nb32) {
return;
}

dst_t * y = yy + 256*i + 32*ir + 4*il;

const block_q4_1 * x = (const block_q4_1 *)vx + ib;
const float2 d = __half22float2(x->dm);

const uint8_t * q = x->qs + 4*il;

for (int l = 0; l < 4; ++l) {
y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
y[l+16] = d.x * (q[l] >> 4) + d.y;
}
}

//================================== k-quants

template<typename dst_t>
Expand Down Expand Up @@ -6253,6 +6308,20 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
#endif
}

template<typename dst_t>
static void dequantize_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
static void dequantize_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

Comment on lines +6311 to +6324
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are named incorrectly. Should be:

  • dequantize_row_q4_0_cuda
  • dequantize_row_q4_1_cuda

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was me naming them like this because these kernels are not just applicable to row-wise dequantization.

template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
Expand Down Expand Up @@ -6301,9 +6370,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
int id;
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
return dequantize_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
return dequantize_q4_1_cuda;
case GGML_TYPE_Q5_0:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
Expand Down Expand Up @@ -6338,9 +6407,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
return dequantize_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
return dequantize_q4_1_cuda;
case GGML_TYPE_Q5_0:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
Expand Down
Loading