Skip to content

Commit

Permalink
refacotr
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Sep 5, 2023
1 parent 3c4a83e commit 44acb8b
Showing 1 changed file with 70 additions and 100 deletions.
170 changes: 70 additions & 100 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3418,23 +3418,20 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
const int mmq_x = MMQ_X_Q4_0_AMPERE;
const int mmq_y = MMQ_Y_Q4_0_AMPERE;
const int nwarps = NWARPS_Q4_0_AMPERE;

mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q4_0_PASCAL;
const int mmq_y = MMQ_Y_Q4_0_PASCAL;
const int nwarps = NWARPS_Q4_0_PASCAL;

mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q4_0_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

#define MMQ_X_Q4_1_AMPERE 64
Expand All @@ -3457,23 +3454,20 @@ template <bool need_check> static __global__ void
const int mmq_x = MMQ_X_Q4_1_AMPERE;
const int mmq_y = MMQ_Y_Q4_1_AMPERE;
const int nwarps = NWARPS_Q4_1_AMPERE;

mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q4_1_PASCAL;
const int mmq_y = MMQ_Y_Q4_1_PASCAL;
const int nwarps = NWARPS_Q4_1_PASCAL;

mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q4_1_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

#define MMQ_X_Q5_0_AMPERE 128
Expand All @@ -3492,23 +3486,20 @@ template <bool need_check> static __global__ void mul_mat_q5_0(
const int mmq_x = MMQ_X_Q5_0_AMPERE;
const int mmq_y = MMQ_Y_Q5_0_AMPERE;
const int nwarps = NWARPS_Q5_0_AMPERE;

mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q5_0_PASCAL;
const int mmq_y = MMQ_Y_Q5_0_PASCAL;
const int nwarps = NWARPS_Q5_0_PASCAL;

mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q5_0_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

#define MMQ_X_Q5_1_AMPERE 128
Expand All @@ -3527,23 +3518,20 @@ template <bool need_check> static __global__ void mul_mat_q5_1(
const int mmq_x = MMQ_X_Q5_1_AMPERE;
const int mmq_y = MMQ_Y_Q5_1_AMPERE;
const int nwarps = NWARPS_Q5_1_AMPERE;

mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q5_1_PASCAL;
const int mmq_y = MMQ_Y_Q5_1_PASCAL;
const int nwarps = NWARPS_Q5_1_PASCAL;

mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q5_1_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

#define MMQ_X_Q8_0_AMPERE 128
Expand All @@ -3562,23 +3550,20 @@ template <bool need_check> static __global__ void mul_mat_q8_0(
const int mmq_x = MMQ_X_Q8_0_AMPERE;
const int mmq_y = MMQ_Y_Q8_0_AMPERE;
const int nwarps = NWARPS_Q8_0_AMPERE;

mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q8_0_PASCAL;
const int mmq_y = MMQ_Y_Q8_0_PASCAL;
const int nwarps = NWARPS_Q8_0_PASCAL;

mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q8_0_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

#define MMQ_X_Q2_K_AMPERE 64
Expand All @@ -3597,23 +3582,20 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
const int mmq_x = MMQ_X_Q2_K_AMPERE;
const int mmq_y = MMQ_Y_Q2_K_AMPERE;
const int nwarps = NWARPS_Q2_K_AMPERE;

mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q2_K_PASCAL;
const int mmq_y = MMQ_Y_Q2_K_PASCAL;
const int nwarps = NWARPS_Q2_K_PASCAL;

mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q2_K_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

#define MMQ_X_Q3_K_AMPERE 128
Expand All @@ -3636,23 +3618,20 @@ template <bool need_check> static __global__ void
const int mmq_x = MMQ_X_Q3_K_AMPERE;
const int mmq_y = MMQ_Y_Q3_K_AMPERE;
const int nwarps = NWARPS_Q3_K_AMPERE;

mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q3_K_PASCAL;
const int mmq_y = MMQ_Y_Q3_K_PASCAL;
const int nwarps = NWARPS_Q3_K_PASCAL;

mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q3_K_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

#define MMQ_X_Q4_K_AMPERE 64
Expand All @@ -3675,23 +3654,20 @@ template <bool need_check> static __global__ void
const int mmq_x = MMQ_X_Q4_K_AMPERE;
const int mmq_y = MMQ_Y_Q4_K_AMPERE;
const int nwarps = NWARPS_Q4_K_AMPERE;

mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q4_K_PASCAL;
const int mmq_y = MMQ_Y_Q4_K_PASCAL;
const int nwarps = NWARPS_Q4_K_PASCAL;

mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q4_K_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

#define MMQ_X_Q5_K_AMPERE 64
Expand All @@ -3710,23 +3686,20 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
const int mmq_x = MMQ_X_Q5_K_AMPERE;
const int mmq_y = MMQ_Y_Q5_K_AMPERE;
const int nwarps = NWARPS_Q5_K_AMPERE;

mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q5_K_PASCAL;
const int mmq_y = MMQ_Y_Q5_K_PASCAL;
const int nwarps = NWARPS_Q5_K_PASCAL;

mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q5_K_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

#define MMQ_X_Q6_K_AMPERE 64
Expand All @@ -3749,23 +3722,20 @@ template <bool need_check> static __global__ void
const int mmq_x = MMQ_X_Q6_K_AMPERE;
const int mmq_y = MMQ_Y_Q6_K_AMPERE;
const int nwarps = NWARPS_Q6_K_AMPERE;

mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);

#elif __CUDA_ARCH__ >= MIN_CC_DP4A
const int mmq_x = MMQ_X_Q6_K_PASCAL;
const int mmq_y = MMQ_Y_Q6_K_PASCAL;
const int nwarps = NWARPS_Q6_K_PASCAL;

mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q6_K_q8_1_mul_mat;
const int mmq_x = -1;
const int mmq_y = -1;
const int nwarps = -1;
assert(false);
#endif // __CUDA_ARCH__ >= CC_TURING

mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y);
}

template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
Expand Down

0 comments on commit 44acb8b

Please sign in to comment.