From d3ab1c77faed3f4182d45971330cf2d0d95059da Mon Sep 17 00:00:00 2001 From: alexm-nm <59768536+alexm-nm@users.noreply.github.com> Date: Thu, 2 May 2024 12:56:22 -0400 Subject: [PATCH] [Kernel] Support running GPTQ 8-bit models in Marlin (#4533) --- csrc/ops.h | 4 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 552 ++++++++++++------ csrc/quantization/gptq_marlin/gptq_marlin.cuh | 8 +- .../gptq_marlin/gptq_marlin_repack.cu | 152 +++-- tests/models/test_gptq_marlin.py | 13 +- vllm/_custom_ops.py | 14 +- .../layers/quantization/gptq_marlin.py | 134 ++--- 7 files changed, 553 insertions(+), 324 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 04b97d1784cd2..8ae052427052f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor &g_idx, torch::Tensor &perm, torch::Tensor &workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, @@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack( torch::Tensor &b_q_weight, torch::Tensor &perm, int64_t size_k, - int64_t size_n); + int64_t size_n, + int64_t num_bits); #endif void squeezellm_gemm( diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 9902f55167d89..fd0837f0cb39c 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -32,7 +32,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, int4 *__restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template = 8.0"); return torch::empty({1, 1}); @@ -114,11 +115,21 @@ template __device__ inline int lop3(int a, int b, int c) { return res; } +// Constructs destination register by taking bytes from 2 sources (based on mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +__device__ inline FragB dequant_4bit(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -139,6 +150,24 @@ __device__ inline FragB dequant(int q) { return frag_b; } +__device__ inline FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
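(Reference note, not part of the patch.) dequant_8bit() above relies on the same fp16 "magic number" trick as dequant_4bit(): prmt places each quantized byte q into the low half of the fp16 pattern 0x64xx, whose value is exactly 1024 + q, and __hsub2 with 0x6480 (1024 + 128) then yields q - 128 with no integer-to-float conversion instructions; by the same arithmetic, the 4-bit path's per-half SUB constant 0x6408 (1024 + 8) produces q - 8. A minimal host-side C++ sketch of that arithmetic follows (the real kernel additionally interleaves bytes 0/2 and 1/3 into two half2 registers via the prmt masks, which this plain loop ignores):

#include <cstdint>
#include <cstdio>

int main() {
  // Four packed uint8 values, as one int32 B-fragment word would hold them.
  uint32_t q = (200u << 24) | (3u << 16) | (130u << 8) | 7u;
  const float magic = 1024.0f + 128.0f;  // value of I8s_TO_F16s_MAGIC_NUM (0x6480)
  for (int i = 0; i < 4; ++i) {
    unsigned byte = (q >> (8 * i)) & 0xFF;
    float fused = 1024.0f + (float)byte;  // value of the fp16 pattern 0x6400 | byte
    printf("q[%d] = %3u  ->  %6.0f\n", i, byte, fused - magic);
  }
  return 0;  // prints -121, 2, -125 and 72, i.e. q - 128 for each byte
}
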
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { @@ -162,6 +191,13 @@ __device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2, frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float *c, FragS &s) { + __half *s_ptr = reinterpret_cast<__half *>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int *lock, int count) { if (threadIdx.x == 0) { @@ -250,7 +286,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, } } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; int same_group_id[stages]; auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); @@ -767,10 +828,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + FragB frag_b0; + FragB frag_b1; + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); - FragB frag_b0 = dequant(b_quant); + } else { + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -782,8 +856,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -808,13 +880,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. 
We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -861,7 +933,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped portioning + // finally have to globally reduce over the results. As the striped partitioning // minimizes the number of such reductions and our outputs are usually rather // small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { @@ -951,13 +1023,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto write = [&](int idx, float c0, float c1, FragS &s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { res = __hmul2(res, s[0]); } ((half2 *)sh)[idx] = res; }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1023,6 +1097,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // ensure all shared memory accesses are static. Note that both pipelines // have even length meaning that the next iteration will always start at // index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { #pragma unroll @@ -1070,23 +1145,63 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (num_bits == 8) { if (s_sh_wr_pred) { - cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (num_bits == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } } } @@ -1125,28 +1240,25 @@ Marlin(const int4 *__restrict__ A, // 
fp16 input matrix of shape mxk s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } - // if (blockIdx.x == 0 && threadIdx.x == 0) { - // printf("Move\n"); - // } start_pipes(); } } } } -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ +#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ - Marlin, \ + Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ + Marlin \ <<>>( \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ prob_k, locks); \ @@ -1158,28 +1270,92 @@ typedef struct { int num_threads; } thread_config_t; -thread_config_t small_batch_thread_configs[] = { +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N + {64, 256, 256}, // Default (max cache usage) + {64, 128, 128}, // Reduce N, reduce warps + {128, 64, 128}, // Reduce N more, but increase K + }; -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority +int get_scales_cache_size(thread_config_t const &th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 128, 128}, // Reduce N 2X, same K - // {128, 64, 128}, // Reduce N 4X, increase K 2X -}; + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; -bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, - int prob_k) { + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > 
scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const &th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -1201,62 +1377,79 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, return false; } + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + return true; } -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - - // TODO: Enable if needed after some more testing - if (prob_m <= 0) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + for (auto th_config : thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; } } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } + printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM " + "GPU cache. This may " + "hurt performance. 
Consider upgrading your GPU.\n"); + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } - return thread_config_t{-1, -1, -1}; + return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ +#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) - -void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, - void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k, - void *workspace, bool has_act_order, bool is_k_full, - int num_groups, int group_size, int dev = 0, - cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, - int sms = -1, int max_par = 16) { + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, + void *g_idx, void *perm, void *a_tmp, int prob_m, + int prob_n, int prob_k, void *workspace, int num_bits, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int dev, cudaStream_t stream, 
int thread_k, + int thread_n, int sms, int max_par) { + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1274,25 +1467,34 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, TORCH_CHECK(max_shared_mem > 0); // Set thread config - thread_config_t th_config; + exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config - th_config = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; } else { // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); } - TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + - " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + - str(prob_n) + "]"); - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; @@ -1352,28 +1554,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, } // Main loop - for (int i = 0; i < tot_m_blocks; i += 4) { + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { int thread_m_blocks = tot_m_blocks - i; prob_m = tot_m - 16 * i; int par = 1; - if (thread_m_blocks > 4) { + if (thread_m_blocks > exec_cfg.max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * thread_m_blocks - pad) / 64; + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; } // Define kernel configurations if (false) { } - CALL_IF(16, 4, 256) - CALL_IF(8, 8, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) + CALL_IF(4, 32, 2, 256) + CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 4, 128) + CALL_IF(4, 4, 8, 128) + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1395,33 +1601,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, torch::Tensor 
gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &g_idx, torch::Tensor &perm, torch::Tensor &workspace, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full) { + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + // Verify A - TORCH_CHECK(a.size(0) == size_m, - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - TORCH_CHECK(a.size(1) == size_k, - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(gptq_marlin::tile_size)); + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(gptq_marlin::tile_size)); - TORCH_CHECK( - b_q_weight.size(1) % gptq_marlin::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(gptq_marlin::tile_size)); - int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) * - gptq_marlin::pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); // Verify device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); @@ -1457,9 +1662,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, // Verify g_idx and perm TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || (g_idx.size(0) == size_k && perm.size(0) == size_k), - "Unexpected g_idx.size(0) = " + str(g_idx.size(0)) + - " and perm.size(0) = " + str(perm.size(0)) + - ", where size_k = " + str(size_k)); + "Unexpected g_idx.size(0) = ", g_idx.size(0), + " and perm.size(0) = ", perm.size(0), + ", where size_k = ", size_k); // Detect groupsize and act_order int num_groups = -1; @@ -1475,9 +1680,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, if (has_act_order) { if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, - "size_k = " + str(size_k) + - ", is not divisible by num_groups = " + str(num_groups)); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); group_size = size_k / num_groups; } else { group_size = 0; @@ -1485,10 +1689,9 @@ torch::Tensor 
gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, } else { if (num_groups > 1) { - TORCH_CHECK(size_k % num_groups == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); group_size = size_k / num_groups; } else { group_size = -1; @@ -1496,23 +1699,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, } // Verify workspace size - TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(gptq_marlin::min_thread_n)); + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); int min_workspace_size = (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - gptq_marlin::marlin_cuda( + gptq_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups, - group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, - sms, gptq_marlin::max_par); + size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, gptq_marlin::max_par); return c; } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh index 8cfce6b2575d5..35ea48aaba310 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; static constexpr int max_par = 16; -static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit - template struct Vec { T elems[n]; @@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool "r"(smem), "l"(glob_ptr), "n"(BYTES)); } -__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("{\n" - " .reg .b64 p;\n" - " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" - " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index fa45ce68a0c77..0d3da6240dbca 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -template +template __global__ void marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const *__restrict__ perm_ptr, @@ -20,7 +20,8 @@ 
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } // namespace gptq_marlin torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, - int64_t size_k, int64_t size_n) { + int64_t size_k, int64_t size_n, + int64_t num_bits) { TORCH_CHECK_NOT_IMPLEMENTED( false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, #else -template +template __global__ void marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const *__restrict__ perm_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + int k_tiles = size_k / tile_k_size; int n_tiles = size_n / tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); @@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, sh_pipe_ptr += perm_size; } + constexpr int tile_ints = tile_k_size / pack_factor; + constexpr int stage_n_threads = tile_n_size / 4; - constexpr int stage_k_threads = - has_perm ? tile_k_size : tile_k_size / pack_factor_4bit; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; auto load_perm_to_shared = [&](int k_tile_id) { @@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, reinterpret_cast(sh_perm_ptr); int src_k = sh_perm_int_ptr[k_id]; - int src_k_packed = src_k / pack_factor_4bit; + int src_k_packed = src_k / pack_factor; - cp_async4_stream( + cp_async4( &sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast(&( b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); @@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int n_id = threadIdx.x % stage_n_threads; int first_k = k_tile_id * tile_k_size; - int first_k_packed = first_k / pack_factor_4bit; + int first_k_packed = first_k / pack_factor; - cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + - first_n + (n_id * 4)]))); + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + + first_n + (n_id * 4)]))); } } @@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int cur_n = warp_id * 16 + tc_col; constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - uint32_t vals[pack_factor_4bit]; + uint32_t vals[8]; if constexpr (has_perm) { for (int i = 0; i < 4; i++) { int k_idx = tc_row + tc_offsets[i]; uint32_t src_k = sh_perm_int_ptr[k_idx]; - uint32_t src_k_pos = src_k % pack_factor_4bit; + uint32_t src_k_pos = src_k % pack_factor; uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; - uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; - uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; vals[i] = b1_cur_val; vals[4 + i] = b2_cur_val; @@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } else { - uint32_t b1_val_1 = 
sh_stage_int_ptr[cur_n]; - uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n]; - - uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8]; - uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8]; + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; #pragma unroll - for (int i = 0; i < 2; i++) { - int cur_elem = tc_row + tc_offsets[i]; - vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf; - vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf; + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } #pragma unroll - for (int i = 2; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i] - 8; - vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf; - vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf; + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; } } + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7}; + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - uint32_t res = 0; + uint32_t res = 0; #pragma unroll - for (int i = 0; i < pack_factor_4bit; i++) { - res |= vals[pack_idx[i]] << (i * 4); - } + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + out_ptr[out_offset + th_id * 4 + warp_id] = res; - out_ptr[out_offset + th_id * 4 + warp_id] = res; + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { @@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } // namespace gptq_marlin +#define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin::marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin::marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } + torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, - int64_t size_k, int64_t size_n) { + int64_t size_k, int64_t size_n, + int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. 
Got = ", num_bits); + int const pack_factor = 32 / num_bits; + // Verify B - TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0), + TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, - ", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit); + ", size_k = ", size_k, ", pack_factor = ", pack_factor); TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n); @@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, auto options = torch::TensorOptions() .dtype(b_q_weight.dtype()) .device(b_q_weight.device()); - torch::Tensor out = torch::empty( - {size_k / gptq_marlin::tile_size, - size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit}, - options); + torch::Tensor out = + torch::empty({size_k / gptq_marlin::tile_size, + size_n * gptq_marlin::tile_size / pack_factor}, + options); // Detect if there is act_order bool has_perm = perm.size(0) != 0; @@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); - if (has_perm) { - cudaFuncSetAttribute( - gptq_marlin::marlin_repack_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_mem); - gptq_marlin::marlin_repack_kernel - <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); - - } else { - cudaFuncSetAttribute( - gptq_marlin::marlin_repack_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_mem); - gptq_marlin::marlin_repack_kernel - <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); + if (false) { + } + CALL_IF(4, false) + CALL_IF(4, true) + CALL_IF(8, false) + CALL_IF(8, true) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, + ", has_perm = ", has_perm); } return out; diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index dc027697ffd4d..4d73843f970c4 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -39,6 +39,13 @@ ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"), # act_order==True, group_size=32 ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"), + + # 8-bit, act_order==True, group_size=channelwise + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"), + # 8-bit, act_order==True, group_size=128 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"), + # 8-bit, act_order==True, group_size=32 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"), ] @@ -65,8 +72,7 @@ def test_models( dtype=dtype, quantization="marlin", max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1, - disable_custom_all_reduce=True) + tensor_parallel_size=1) gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) @@ -78,8 +84,7 @@ def test_models( dtype=dtype, quantization="gptq", max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1, - disable_custom_all_reduce=True) + tensor_parallel_size=1) gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4af8b09b1e16c..3faed5ea85307 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, # 
gptq_marlin def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n) + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, size_m: int, - size_n: int, size_k: int, + perm: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: int, size_n: int, size_k: int, is_k_full: bool) -> torch.Tensor: return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, - workspace, size_m, size_n, size_k, - is_k_full) + workspace, num_bits, size_m, size_n, + size_k, is_k_full) # fp8 diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index efbffa0878c4b..e2464008a875f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Any, Dict, List, Optional -import numpy import torch from torch.nn.parameter import Parameter @@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 -GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4] +GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] GPTQ_MARLIN_SUPPORTED_SYM = [True] -# Precompute permutations for Marlin weight and scale shuffling -# -# Marlin works on [16,64] tiles. The goal of the permutations -# is to reorder the weight data so that it is compatible -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the -# kernel will get the data as it is needed for tensor-core -# (without the need to use ldmatrix instructions) -def _get_perms(): - perm = [] - for i in range(32): - perm1 = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm) - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore - perm = torch.from_numpy(perm) +# Permutations for Marlin scale shuffling +def get_scale_perms(num_bits): scale_perm = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) @@ -59,23 +30,21 @@ def _get_perms(): for i in range(4): scale_perm_single.extend( [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return perm, scale_perm, scale_perm_single - - -_perm, _scale_perm, _scale_perm_single = _get_perms() + return scale_perm, scale_perm_single def get_pack_factor(num_bits): - assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, ( - f"Unsupported num_bits = {num_bits}") + assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" return 32 // num_bits -def marlin_permute_scales(s, size_k, size_n, group_size): +def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): + scale_perm, scale_perm_single = get_scale_perms(num_bits) if group_size < size_k and group_size != -1: - s = s.reshape((-1, 
len(_scale_perm)))[:, _scale_perm] + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: - s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s @@ -279,13 +248,15 @@ def create_weights( requires_grad=False, ) set_weight_attrs( - qweight, { + qweight, + { **extra_weight_attrs, "input_dim": 0, "output_dim": 1, "packed_dim": 0, "pack_factor": self.quant_config.pack_factor, - }) + }, + ) # Activation order g_idx = Parameter( @@ -296,10 +267,13 @@ def create_weights( requires_grad=False, ) # Ignore warning from fused linear layers such as QKVParallelLinear. - set_weight_attrs(g_idx, { - **extra_weight_attrs, "input_dim": 0, - "ignore_warning": True - }) + set_weight_attrs( + g_idx, + { + **extra_weight_attrs, "input_dim": 0, + "ignore_warning": True + }, + ) g_idx_sort_indices = Parameter( torch.empty( @@ -320,29 +294,34 @@ def create_weights( requires_grad=False, ) set_weight_attrs( - scales, { + scales, + { **extra_weight_attrs, "input_dim": scales_and_zp_input_dim, "output_dim": 1, - }) + }, + ) # Quantized zero-points qzeros = Parameter( - torch.empty(scales_and_zp_size, - output_size_per_partition // - self.quant_config.pack_factor, - dtype=torch.int32, - device="meta"), + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + device="meta", + ), requires_grad=False, ) set_weight_attrs( - qzeros, { + qzeros, + { **extra_weight_attrs, "input_dim": scales_and_zp_input_dim, "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - }) + }, + ) # Allocate marlin workspace max_workspace_size = ( @@ -405,13 +384,14 @@ def replace_tensor(name, new_t): else: # Reset g_idx related tensors - layer.g_idx = Parameter(torch.empty(0, - dtype=torch.int, - device=cur_device), - requires_grad=False) - layer.g_idx_sort_indices = Parameter(torch.empty( - 0, dtype=torch.int, device=cur_device), - requires_grad=False) + layer.g_idx = Parameter( + torch.empty(0, dtype=torch.int, device=cur_device), + requires_grad=False, + ) + layer.g_idx_sort_indices = Parameter( + torch.empty(0, dtype=torch.int, device=cur_device), + requires_grad=False, + ) # Repack weights marlin_qweight = ops.gptq_marlin_repack( @@ -419,6 +399,7 @@ def replace_tensor(name, new_t): layer.g_idx_sort_indices, part_size_k, part_size_n, + self.quant_config.weight_bits, ) replace_tensor("qweight", marlin_qweight) @@ -428,15 +409,28 @@ def replace_tensor(name, new_t): if self.quant_config.desc_act: scales_size_k = full_size_k - marlin_scales = marlin_permute_scales(layer.scales, scales_size_k, - scales_size_n, - self.quant_config.group_size) + marlin_scales = marlin_permute_scales( + layer.scales, + scales_size_k, + scales_size_n, + self.quant_config.group_size, + self.quant_config.weight_bits, + ) replace_tensor("scales", marlin_scales) - output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales, - layer.g_idx, layer.g_idx_sort_indices, - layer.workspace, size_m, part_size_n, - part_size_k, layer.is_k_full) + output = ops.gptq_marlin_gemm( + reshaped_x, + layer.qweight, + layer.scales, + layer.g_idx, + layer.g_idx_sort_indices, + layer.workspace, + self.quant_config.weight_bits, + size_m, + part_size_n, + part_size_k, + layer.is_k_full, + ) if bias is not None: output.add_(bias) # In-place add
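
(Closing reference note, not part of the patch.) As a sanity check on the num_bits plumbing added above: gptq_marlin_repack() allocates its output as (size_k / tile_size) x (size_n * tile_size / pack_factor) int32 words, while the GPTQ input qweight is (size_k / pack_factor) x size_n, so both layouts carry exactly size_k * size_n / pack_factor packed words for either bit width; moving from 4-bit to 8-bit halves the pack factor and doubles the repacked tensor's second dimension. A small standalone C++ sketch of that bookkeeping (the layer dimensions are illustrative only):

#include <cstdio>

int main() {
  const int tile_size = 16;                 // gptq_marlin::tile_size
  const int size_k = 4096, size_n = 11008;  // example layer dims, not from the patch
  const int bit_widths[] = {4, 8};
  for (int num_bits : bit_widths) {
    const int pack_factor = 32 / num_bits;
    // GPTQ input layout:    (size_k / pack_factor) x size_n
    // Marlin output layout: (size_k / tile_size) x (size_n * tile_size / pack_factor)
    printf("bits=%d  gptq=(%d, %d)  marlin=(%d, %d)\n", num_bits,
           size_k / pack_factor, size_n, size_k / tile_size,
           size_n * tile_size / pack_factor);
  }
  return 0;
}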