Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Support running GPTQ 8-bit models in Marlin #4533

Merged
merged 9 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor &g_idx,
torch::Tensor &perm,
torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k,
Expand All @@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight,
torch::Tensor &perm,
int64_t size_k,
int64_t size_n);
int64_t size_n,
int64_t num_bits);
#endif

void squeezellm_gemm(
Expand Down
569 changes: 394 additions & 175 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu

Large diffs are not rendered by default.

10 changes: 3 additions & 7 deletions csrc/quantization/gptq_marlin/gptq_marlin.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ static constexpr int default_threads = 256;
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int min_thread_k = 32;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit

template <typename T, int n>
struct Vec {
T elems[n];
Expand All @@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .b64 p;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
Expand Down
152 changes: 90 additions & 62 deletions csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

template <int const num_threads, bool const has_perm>
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
Expand All @@ -20,19 +20,22 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} // namespace gptq_marlin

torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) {
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}

#else

template <int const num_threads, bool const has_perm>
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;

int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
Expand Down Expand Up @@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
sh_pipe_ptr += perm_size;
}

constexpr int tile_ints = tile_k_size / pack_factor;

constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads =
has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads;

auto load_perm_to_shared = [&](int k_tile_id) {
Expand Down Expand Up @@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
reinterpret_cast<uint32_t const *>(sh_perm_ptr);

int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor_4bit;
int src_k_packed = src_k / pack_factor;

cp_async4_stream(
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
Expand All @@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int n_id = threadIdx.x % stage_n_threads;

int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor_4bit;
int first_k_packed = first_k / pack_factor;

cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
}
}

Expand All @@ -145,68 +149,84 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int cur_n = warp_id * 16 + tc_col;

constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;

int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);

uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);

uint32_t vals[pack_factor_4bit];
uint32_t vals[8];

if constexpr (has_perm) {
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];

uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor_4bit;
uint32_t src_k_pos = src_k % pack_factor;

uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf;
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;

uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf;
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;

vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val;
}

} else {

uint32_t b1_val_1 = sh_stage_int_ptr[cur_n];
uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n];

uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];

#pragma unroll
for (int i = 0; i < 2; i++) {
int cur_elem = tc_row + tc_offsets[i];
vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf;
vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}

#pragma unroll
for (int i = 2; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i] - 8;
vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf;
vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf;
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor;

vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
}
}

constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7};
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

uint32_t res = 0;
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < pack_factor_4bit; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}

constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
out_ptr[out_offset + th_id * 4 + warp_id] = res;

out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};

uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}

out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};

auto start_pipes = [&](int k_tile_id, int n_tile_id) {
Expand Down Expand Up @@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,

} // namespace gptq_marlin

// CALL_IF(NUM_BITS, HAS_PERM): expands to one `else if` branch that dispatches
// to the marlin_repack_kernel instantiation matching a (num_bits, has_perm)
// pair.
//
// The branch is taken when the runtime values `num_bits` and `has_perm` equal
// the compile-time arguments NUM_BITS and HAS_PERM. It then:
//   1. calls cudaFuncSetAttribute to set
//      cudaFuncAttributeMaxDynamicSharedMemorySize to `max_shared_mem` for
//      that kernel instantiation, and
//   2. launches the kernel with `blocks` blocks of gptq_marlin::repack_threads
//      threads, `max_shared_mem` bytes of dynamic shared memory, on `stream`.
//
// Because the expansion starts with `else`, it must directly follow an `if`
// (or a previous CALL_IF) statement. It also relies on these names being
// visible at the expansion site: num_bits, has_perm, max_shared_mem, blocks,
// stream, b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n.
//
// NOTE(review): the return value of cudaFuncSetAttribute is not checked here.
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}

torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) {
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);

TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;

// Verify B
TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0),
TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k,
", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
", size_k = ", size_k, ", pack_factor = ", pack_factor);
TORCH_CHECK(b_q_weight.size(1) == size_n,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not size_n = ", size_n);
Expand All @@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out = torch::empty(
{size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit},
options);
torch::Tensor out =
torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / pack_factor},
options);

// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
Expand All @@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);

if (has_perm) {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);

} else {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
if (false) {
}
CALL_IF(4, false)
CALL_IF(4, true)
CALL_IF(8, false)
CALL_IF(8, true)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
}

return out;
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ def generate_greedy_logprobs(
return all_logprobs

def __del__(self):
del self.model
if self.model is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why needed? @alexm-nm

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got occasional warnings from `del self.model` when `self.model` was already None. The check is not necessary for correctness, so it can be dropped if it causes issues in CI.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should touch this file.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

del self.model
self.model = None
cleanup()


Expand Down
25 changes: 17 additions & 8 deletions tests/models/test_gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,24 @@
capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())

MODELS = [
# act_order==False, group_size=channelwise
# 4-bit, act_order==False, group_size=channelwise
("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"),
# act_order==False, group_size=128
# 4-bit, act_order==False, group_size=128
("TheBloke/Llama-2-7B-GPTQ", "main"),

# act_order==True, group_size=128
# 4-bit, act_order==True, group_size=128
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"),
# act_order==True, group_size=64
# 4-bit, act_order==True, group_size=64
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
# act_order==True, group_size=32
# 4-bit, act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),

# 8-bit, act_order==True, group_size=channelwise
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"),
# 8-bit, act_order==True, group_size=128
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"),
# 8-bit, act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"),
]


Expand All @@ -63,10 +70,11 @@ def test_models(
gptq_marlin_model = vllm_runner(model_name=model_name,
revision=revision,
dtype=dtype,
quantization="marlin",
quantization="gptq",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexm-nm this test should have marlin for quantization

Also - is enforce_eager=True required?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that must be a leftover from debugging. Good catch; I will fix it in 30 minutes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, tests pass

max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1,
disable_custom_all_reduce=True)
disable_custom_all_reduce=True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to make this cleaner, can we remove disable_custom_all_reduce?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will try

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

works

enforce_eager=True)

gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
Expand All @@ -79,7 +87,8 @@ def test_models(
quantization="gptq",
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1,
disable_custom_all_reduce=True)
disable_custom_all_reduce=True,
enforce_eager=True)
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)
Expand Down
Loading
Loading