SOTA 2-bit quants #4773
@@ -477,6 +477,14 @@ typedef struct {
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");

#define QR2_XXS 8
#define QI2_XXS (QK_K / (4*QR2_XXS))
typedef struct {
    half d;
    uint16_t qs[QK_K/8];
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");

#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

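For orientation, the storage arithmetic behind the new block (assuming the usual QK_K == 256): sizeof(block_iq2_xxs) is 2 bytes for the fp16 scale d plus QK_K/8 = 32 uint16 values, i.e. 2 + 32*2 = 66 bytes per 256 weights, which works out to 66*8/256 = 2.0625 bits per weight.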
@@ -1292,6 +1300,128 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
#endif
}

static const __device__ uint64_t kgrid_iq2xxs[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};

static const __device__ uint8_t ksigns_iq2xs[128] = {
    0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
    144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
    160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
    48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
    192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
    80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
    96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};

static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};

inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
            return true;
        default:
            return false;
    }
}

||
template<typename dst_t> | ||
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { | ||
|
||
const int i = blockIdx.x; | ||
const block_iq2_xxs * x = (const block_iq2_xxs *) vx; | ||
|
||
const int tid = threadIdx.x; | ||
#if QK_K == 256 | ||
const int il = tid/8; // 0...3 | ||
const int ib = tid%8; // 0...7 | ||
dst_t * y = yy + i*QK_K + 32*ib + 8*il; | ||
const uint16_t * q2 = x[i].qs + 4*ib; | ||
const uint8_t * aux8 = (const uint8_t *)q2; | ||
const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]); | ||
const uint32_t aux32 = q2[2] | (q2[3] << 16); | ||
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f; | ||
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; | ||
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); | ||
#else | ||
assert(false); | ||
#endif | ||
|
||
} | ||
|
||
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { | ||
|
||
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); | ||
|

@@ -3825,6 +3955,55 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
}

static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;

#if QR2_XXS == 8
    const int ib32 = iqs;
    const uint16_t * q2 = bq2->qs + 4*ib32;
    const uint8_t * aux8 = (const uint8_t *)q2;
    const int8_t * q8 = bq8_1[ib32].qs;
    uint32_t aux32 = q2[2] | (q2[3] << 16);
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
        for (int j = 0; j < 8; ++j) {
            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
        }
        q8 += 8;
        aux32 >>= 7;
    }
    const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
    return d * sumi;
#else
    // iqs is 0...15
    const int ib32 = iqs/2;
    const int il = iqs%2;
    const uint16_t * q2 = bq2->qs + 4*ib32;
    const uint8_t * aux8 = (const uint8_t *)q2;
    const uint8_t * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
    const uint8_t * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
    const uint32_t aux32 = q2[2] | (q2[3] << 16);
    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
    int sumi1 = 0, sumi2 = 0;
    for (int j = 0; j < 8; ++j) {
        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
    }
    return d * (sumi1 + sumi2);
#endif

Comment on lines +3981 to +4000:

"I assume this is for testing different versions."

Yes. The one currently active is slightly faster on my RTX-4080, but I left the other version behind (which was the initial implementation) just in case. You never know with all these different cards that are being supported.

#else
    assert(false);
    return 0.f;
#endif
}

template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(

@@ -5664,6 +5843,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
#endif
}

template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

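As a quick cross-check of the launch geometry (assuming QK_K == 256): dequantize_block_iq2_xxs is launched with nb = k / QK_K blocks of 32 threads, and each thread writes 8 values (il = tid/8, ib = tid%8 in the kernel above), so a CUDA block produces 32 * 8 = 256 = QK_K outputs, i.e. exactly one iq2_xxs block.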

template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;

@@ -5692,6 +5877,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
        case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
        default:

@@ -5721,6 +5908,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
        case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_F16:
            return convert_unary_cuda<half>;
        default:

@@ -5915,6 +6104,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}

static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}

static void ggml_mul_mat_q4_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

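For the template arguments above, note that with QR2_XXS == 8 and QK_K == 256 the earlier #define gives QI2_XXS = QK_K / (4*QR2_XXS) = 8, which lines up with the QR2_XXS == 8 branch of vec_dot_iq2_xxs_q8_1, where iqs directly indexes one of the eight 32-weight groups of a block.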
@@ -7407,6 +7605,7 @@ static int64_t get_row_rounding(ggml_type type) {
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
        default:
            GGML_ASSERT(false);

@@ -7427,6 +7626,7 @@ static int64_t get_row_rounding(ggml_type type) {
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_IQ2_XXS:
            return max_compute_capability >= CC_VOLTA ? 128 : 64;
        case GGML_TYPE_Q6_K:
            return 64;

@@ -7477,6 +7677,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
            break;
        default:
            GGML_ASSERT(false);
            break;

@@ -8693,6 +8896,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1

#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);

    // debug helpers
    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);

Does this have the same effect as __constant__? In other words, does it actually put these values into constant memory? (Should be faster than if it is in global memory.)

Just tried. Replacing __device__ with __constant__ makes it massively slower (108 t/s vs 155 t/s).
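For concreteness, a sketch of the two declarations being compared; only the storage qualifier changes, the table contents stay as in the diff above:

// variant used in this PR: ordinary device (global) memory, read through the regular caches
static const __device__ uint64_t kgrid_iq2xxs[256] = { 0x0808080808080808, /* ... */ };

// variant tried in the comment above: dedicated constant memory
// (reported here as much slower for this access pattern: 108 t/s vs 155 t/s)
static __constant__ uint64_t kgrid_iq2xxs[256] = { 0x0808080808080808, /* ... */ };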
I thought about this, but did not do it (yet). Mainly because the tables need a different kind of declaration on each backend (__device__ on CUDA, constexpr on Metal, etc.), so one either needs to work with pre-processor trickery, or needs to define the actual content as a macro. I did not like either option too much. But yes, absolutely, this is something one should consider.
I see. The compiler is probably not copying the data from constant memory to registers. So for frequently used data it's slower as long as there is no register spilling.
Yes, I know about this. It is context and model dependent. For some models (e.g. Falcon-7B) the difference between quantizing and not quantizing the hidden state can be quite dramatic. This is why the dequantize_mul_mat_vec kernels are actually useful, and I'm somewhat surprised they have fallen out of favor.

So, the MMQ kernels are mostly useful for contexts in the range of a few to a few tens of tokens. I know this is an important use case for things such as speculative sampling, but in my private repo I have an MMQ implementation based on plain vector dot products that outperforms MMQ for, say, up to 16 tokens. If one could extend this up, and/or extend the dequantize/cuBLAS performance superiority down to fewer tokens, the MMQ kernels become unnecessary. This is the main reason I'm kind of reluctant with those, especially considering the amount of code and compilation time increase each new MMQ kernel adds.
Originally the MMQ kernels were intended for large matrices. But it later turned out that the FP32 cuBLAS GEMM does not actually use tensor cores and that FP16 cuBLAS GEMM is still faster for Volta or newer. Georgi then repurposed the MMQ kernels for small batch sizes by changing the tile sizes. They were never intended or optimized by me for this use case, and in my testing they still perform worse than FP16 cuBLAS even for small batch sizes. However, because you do not need to dequantize the weight matrix, MMQ should still be more efficient in terms of VRAM. Also, on Pascal/RDNA2 or older there are no tensor cores, so it is also faster than cuBLAS GEMM by a factor of ~2.

I was thinking that you could extend MMVQ to allow for >1 y columns and probably get better performance than with MMQ/cuBLAS GEMM. Presumably this is very similar to what you have.
But to follow up on your __constant__ comment: I did copy the data into shared memory on Metal. This boosted TG from about 48 t/s to ~54 t/s, so a 12.5% speedup. There it was easy to implement without changing any other kernel. I was thinking that one could gain some performance on CUDA too by copying the grid/sign data to shared memory, but I didn't see how to do it without changing the MMVQ template and touching every single dot product kernel, so I left it as is for now.
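A minimal sketch of that idea on the CUDA side (not part of this PR; the kernel shell and the kgrid_sh name are hypothetical): each block stages the 2 KiB grid table into shared memory once, and the dot products would then read kgrid_sh instead of the global table.

__global__ void some_iq2_xxs_kernel(/* ... */) {
    __shared__ uint64_t kgrid_sh[256];                 // 256 * 8 bytes = 2 KiB per block
    for (int i = threadIdx.x; i < 256; i += blockDim.x) {
        kgrid_sh[i] = kgrid_iq2xxs[i];                 // cooperative copy from global memory
    }
    __syncthreads();
    // ... dot products would index kgrid_sh[aux8[l]] instead of kgrid_iq2xxs[aux8[l]] ...
}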
I don't know how shared memory on Metal works, but with CUDA there are by default 32 shared memory banks with a size of 4 bytes each. The values here are uint64 = 8 bytes, so unless you can ensure that the threads in a warp access different memory banks ((pointer % (32*4 bytes)) / 4 bytes needs to be different), I would expect a lot of memory bank conflicts, which drastically reduce the memory bandwidth.

From what I can tell you have not yet published any models using the new format, so I cannot test this myself, but with NVIDIA Nsight Compute, under the occupancy section, you can see the number of registers needed per thread. If that number is not limiting occupancy, I would not expect much of a performance gain from moving the values to shared memory (but doing that may still reduce cache evictions of other data, so I could be wrong).
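To make the quoted rule concrete (illustrative helper, not part of the code base): with 32 four-byte banks, element i of a uint64_t shared-memory array starts at byte 8*i, so it occupies banks (2*i) % 32 and (2*i + 1) % 32 and always straddles two adjacent banks.

// Illustrative only: bank of a byte address, per the (pointer % (32*4)) / 4 rule above.
__device__ __forceinline__ int smem_bank(const void * p) {
    return (int)(((uintptr_t)p % (32*4)) / 4);
}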