perf: accelerate gqa performance (flashinfer-ai#356)
Changes:
1. Prefetch page indices (we had already applied this optimization to the
decode kernels, but not to the append/prefill kernels used for GQA); a
simplified sketch of the pattern is shown after this list.
2. Unlock the 1x4 warp layout introduced in
flashinfer-ai#322. We had not enabled it before because it made the binary
size too large; some unnecessary template arguments should still be removed
to shrink it further. The second sketch below shows how the layouts map to
query rows per CTA.
3. Optimize `threadblock_sync_mdo_states` to merge the attention states of
multiple warps in a threadblock more efficiently. Our previous implementation
assumed a small shared memory budget and interleaved shared memory
reads/writes with computation, which is less efficient than a single bulk
pass over shared memory (third sketch below).
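
A minimal, self-contained sketch of the prefetch pattern from change 1 (a toy
illustration, not the flashinfer kernel): gather offsets are resolved into a
per-thread register array before the copy loop, so the loop that actually moves
data does no page-table arithmetic. The toy page size, float elements, and plain
global loads stand in for `paged_kv_t`, `get_k_elem_offset`, and the asynchronous
128-bit copies into shared memory used by the real kernels.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

constexpr int kPageSize = 16;        // entries per logical page (toy value)
constexpr int kCopiesPerThread = 4;  // entries gathered by each thread

__global__ void paged_gather(const float* __restrict__ kv_data,
                             const int* __restrict__ page_indices, int num_pages,
                             float* __restrict__ out, int seq_len) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  // Step 1: prefetch page indices -- resolve every source offset up front.
  long long offset[kCopiesPerThread];
#pragma unroll
  for (int i = 0; i < kCopiesPerThread; ++i) {
    const int packed = tid * kCopiesPerThread + i;  // logical position in the sequence
    const int page_iter = packed / kPageSize;       // which logical page
    const int entry_idx = packed % kPageSize;       // entry within that page
    offset[i] = (page_iter < num_pages)
                    ? (long long)page_indices[page_iter] * kPageSize + entry_idx
                    : -1;  // guard against reading past the last page
  }

  // Step 2: the copy loop only moves data; no divmod or page-table lookups here.
#pragma unroll
  for (int i = 0; i < kCopiesPerThread; ++i) {
    const int dst = tid * kCopiesPerThread + i;
    if (dst < seq_len) out[dst] = (offset[i] >= 0) ? kv_data[offset[i]] : 0.f;
  }
}

int main() {
  const int num_pages = 8, seq_len = num_pages * kPageSize;
  float *kv_data, *out;
  int *page_indices;
  cudaMallocManaged(&kv_data, seq_len * sizeof(float));
  cudaMallocManaged(&out, seq_len * sizeof(float));
  cudaMallocManaged(&page_indices, num_pages * sizeof(int));
  for (int i = 0; i < seq_len; ++i) kv_data[i] = (float)i;
  for (int p = 0; p < num_pages; ++p) page_indices[p] = num_pages - 1 - p;  // scattered pages
  const int threads = 32;
  const int blocks = (seq_len / kCopiesPerThread + threads - 1) / threads;
  paged_gather<<<blocks, threads>>>(kv_data, page_indices, num_pages, out, seq_len);
  cudaDeviceSynchronize();
  printf("out[0] = %.0f (expect %d)\n", out[0], (num_pages - 1) * kPageSize);
  return 0;
}
```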
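
For change 2, this compile-time sketch shows how the three layouts translate into
query rows per CTA, under the assumption that rows per CTA equals
num_warps_x * num_frags_x * 16 (one 16-row MMA fragment per num_frags_x step),
which is consistent with the >64 / >16 thresholds used in the updated warp-layout
selection:

```cpp
#include <cstdint>

// Assumption: only warps and fragments laid out along the query dimension
// contribute query rows; 16 rows per MMA fragment.
constexpr uint32_t rows_per_cta(uint32_t num_warps_x, uint32_t num_frags_x) {
  return num_warps_x * num_frags_x * 16;
}

static_assert(rows_per_cta(4, 2) == 128, "k4x1x2: long queries");
static_assert(rows_per_cta(4, 1) == 64, "k4x1x1: medium queries");
static_assert(rows_per_cta(1, 1) == 16, "k1x4x1: short GQA queries, 4 warps on KV");
```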
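
For change 3, a stripped-down sketch of the bulk merge pattern (illustration only:
the real `threadblock_sync_mdo_states` also stores the per-row m/d softmax state
and rescales each warp's partial output before accumulating, and the fragment
shapes differ). Each warp writes all of its output fragments to shared memory
once, the block synchronizes once, and every warp then accumulates the other
warps' fragments, instead of a per-fragment write/sync/read round-trip interleaved
with the computation. The helper compiles as a standalone CUDA translation unit
but is not wired into a kernel here.

```cuda
#include <cstdint>

// Toy bulk merge across num_warps_z warps (m/d rescaling omitted for brevity).
// Shared-memory layout: [num_warps_z][num_frags][32 lanes][8 floats].
template <uint32_t num_warps_z, uint32_t num_frags>
__device__ void bulk_merge(float (&o_frag)[num_frags][8], float* smem_workspace,
                           uint32_t warp_idx_z, uint32_t lane_idx) {
  // 1) One bulk write: each warp dumps all of its fragments at once.
#pragma unroll
  for (uint32_t f = 0; f < num_frags; ++f) {
#pragma unroll
    for (uint32_t j = 0; j < 8; ++j) {
      smem_workspace[((warp_idx_z * num_frags + f) * 32 + lane_idx) * 8 + j] =
          o_frag[f][j];
    }
  }
  __syncthreads();

  // 2) One bulk read + accumulate per fragment, with no further barriers,
  //    instead of interleaving write/__syncthreads()/read with the computation.
#pragma unroll
  for (uint32_t f = 0; f < num_frags; ++f) {
    float acc[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
#pragma unroll
    for (uint32_t w = 0; w < num_warps_z; ++w) {
#pragma unroll
      for (uint32_t j = 0; j < 8; ++j) {
        acc[j] += smem_workspace[((w * num_frags + f) * 32 + lane_idx) * 8 + j];
      }
    }
#pragma unroll
    for (uint32_t j = 0; j < 8; ++j) o_frag[f][j] = acc[j];
  }
}
```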

After this PR, the GQA kernel execution time (on H100) for the setting
`batch_size=128, seq_len=1024, num_qo_heads=32, num_kv_heads=4,
head_dim=128` improves from 133us to 103us.
yzh119 authored Jul 4, 2024
1 parent 2e64a65 commit e56ddad
Showing 5 changed files with 84 additions and 52 deletions.
7 changes: 6 additions & 1 deletion include/flashinfer/attention/handler.cuh
@@ -560,7 +560,12 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz
if (avg_packed_qo_len > 64 && head_dim < 256) {
warp_layout = WarpLayout::k4x1x2; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2)
} else {
warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
if (avg_packed_qo_len > 16) {
warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1)
} else {
// avg_packed_qo_len <= 16
warp_layout = WarpLayout::k1x4x1; // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1)
}
}
const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout);

96 changes: 60 additions & 36 deletions include/flashinfer/attention/prefill.cuh
@@ -53,7 +53,7 @@ constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags
uint32_t num_warps_z) {
return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) ||
(num_frags_y > 4 && num_frags_y % (2 * num_warps_x) != 0) ||
(num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 200));
(num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256));
}

/*!
@@ -207,30 +207,20 @@ template <bool produce_v, uint32_t num_warps_x, uint32_t num_warps_z, uint32_t n
__device__ __forceinline__ void page_produce_kv(
smem_t smem, uint32_t* smem_offset,
paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
const size_t* kv_offset, const uint32_t kv_len) {
constexpr SharedMemFillMode fill_mode =
produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_warps = num_warps_x * num_warps_z;
constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
const uint32_t warp_idx = get_warp_idx<num_warps_x, num_warps_z>(), lane_idx = threadIdx.x;
const uint32_t kv_head_idx = blockIdx.z;
uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8;
// NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps
static_assert(num_frags_z * 4 % num_warps_x == 0);
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
uint32_t page_iter, entry_idx;
paged_kv.page_size.divmod(
packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps * i, page_iter,
entry_idx);
DType* gptr = produce_v
? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
(lane_idx % 8) * num_elems_per_128b<DType>(),
last_indptr)
: paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
(lane_idx % 8) * num_elems_per_128b<DType>(),
last_indptr);
DType* gptr = produce_v ? paged_kv.data + paged_kv.kv_offset_delta() + kv_offset[i]
: paged_kv.data + kv_offset[i];
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
@@ -800,9 +790,21 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
const uint32_t lane_idx) {
// only necessary when blockDim.z > 1
if constexpr (num_warps_z > 1) {
float2* smem_md = (float2*)smem_workspace;
// o: [num_warps, warp_size, 8]
// md: [num_warps, num_frags_x, 2, warp_size, 2 (m/d)]
float2* smem_md = (float2*)(smem_workspace + num_frags_x * num_frags_y * num_warps_x *
num_warps_z * warp_size * 8);
// o: [num_warps, num_frags_x, num_frags_y, warp_size(32), 8]
// md: [num_warps, num_frags_x, 2, warp_size(32), 2 (m/d)]
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
vec_t<float, 8>::memcpy(
smem_workspace +
(((warp_idx * num_frags_x + fx) * num_frags_y + fy) * warp_size + lane_idx) * 8,
o_frag[fx][fy]);
}
}

#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
@@ -851,23 +853,22 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
}
}

__syncthreads();

// the following code saves shared memory usage.
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
vec_t<float, 8> o_new;
o_new.fill(0.f);
vec_t<float, 8>::memcpy(smem_workspace + (warp_idx * warp_size + lane_idx) * 8,
o_frag[fx][fy]);
__syncthreads();
#pragma unroll
for (uint32_t i = 0; i < num_warps_z; ++i) {
vec_t<float, 8> oi;
oi.load(smem_workspace +
((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * warp_size +
((((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * num_frags_x +
fx) *
num_frags_y +
fy) *
warp_size +
lane_idx) *
8);
#pragma unroll
@@ -876,7 +877,6 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
}
}
o_new.store(o_frag[fx][fy]);
__syncthreads();
}
}
}
@@ -1592,6 +1592,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
smem_t k_smem(smem + (num_warps_x * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)),
v_smem(smem + (num_warps_x * num_frags_x + num_warps_z * num_frags_z) * 16 * head_dim *
sizeof(DTypeIn));
size_t kv_offset[num_frags_z * 4 / num_warps_x];

uint32_t k_smem_offset_r = smem_t::get_permuted_offset<channel_size_128b_in>(
get_warp_idx_z<num_warps_x, num_warps_z>() * num_frags_z * 16 + 8 * (lane_idx / 16) +
@@ -1605,13 +1606,22 @@
const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start;
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
uint32_t page_iter, entry_idx;
paged_kv.page_size.divmod(
packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
page_iter, entry_idx);
kv_offset[i] =
page_iter < last_indptr
? paged_kv.get_k_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx,
entry_idx, (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
: 0;
}
page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
last_indptr);
k_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
cp_async::commit_group();
page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
last_indptr);
v_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
cp_async::commit_group();

const uint32_t num_iterations =
@@ -1631,8 +1641,20 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
: chunk_end - chunk_start) /
(16 * num_warps_z * num_frags_z);

#pragma unroll
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
packed_page_iter_base += 16 * num_warps_z * num_frags_z;
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
uint32_t page_iter, entry_idx;
paged_kv.page_size.divmod(
packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
page_iter, entry_idx);
kv_offset[i] = page_iter < last_indptr
? paged_kv.get_k_elem_offset(
__ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
(lane_idx % 8) * num_elems_per_128b<DTypeIn>())
: 0;
}
cp_async::wait_group<1>();
block.sync();

@@ -1677,11 +1699,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);

block.sync();
packed_page_iter_base += 16 * num_warps_z * num_frags_z;
page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv,
chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
last_indptr);
chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
cp_async::commit_group();
cp_async::wait_group<1>();
block.sync();
@@ -1693,8 +1713,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
block.sync();
page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv,
chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
last_indptr);
chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
cp_async::commit_group();
}
cp_async::wait_group<0>();
@@ -1764,10 +1783,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
const uint_fastdiv group_size_fastdiv(group_size);
constexpr uint32_t num_frags_y = HEAD_DIM / 16;
WarpLayout warp_layout;
if (qo_len * group_size > 64 && HEAD_DIM < 256) {
int64_t unpacked_qo_len = qo_len * group_size;
if (unpacked_qo_len > 64 && HEAD_DIM < 256) {
warp_layout = WarpLayout::k4x1x2;
} else {
warp_layout = WarpLayout::k4x1x1;
if (unpacked_qo_len > 16) {
warp_layout = WarpLayout::k4x1x1;
} else {
warp_layout = WarpLayout::k1x4x1;
}
}

DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
29 changes: 16 additions & 13 deletions include/flashinfer/attention/warp_layout.cuh
@@ -26,7 +26,7 @@ namespace flashinfer {
enum class WarpLayout {
k4x1x2 = 0U,
k4x1x1 = 1U,
// k1x4x1 = 2U,
k1x4x1 = 2U,
};

template <WarpLayout warp_layout>
@@ -44,10 +44,10 @@ constexpr uint32_t get_num_warps_x<WarpLayout::k4x1x1>() {
return 4;
}

// template <>
// constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
// return 1;
// }
template <>
constexpr uint32_t get_num_warps_x<WarpLayout::k1x4x1>() {
return 1;
}

template <WarpLayout warp_layout>
constexpr uint32_t get_num_warps_z() {
@@ -64,10 +64,10 @@ constexpr uint32_t get_num_warps_z<WarpLayout::k4x1x1>() {
return 1;
}

// template <>
// constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
// return 4;
// }
template <>
constexpr uint32_t get_num_warps_z<WarpLayout::k1x4x1>() {
return 4;
}

template <WarpLayout warp_layout>
constexpr uint32_t get_num_frags_x() {
@@ -84,10 +84,10 @@ constexpr uint32_t get_num_frags_x<WarpLayout::k4x1x1>() {
return 1;
}

// template <>
// constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
// return 1;
// }
template <>
constexpr uint32_t get_num_frags_x<WarpLayout::k1x4x1>() {
return 1;
}

#define DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, ...) \
if (warp_layout == WarpLayout::k4x1x2) { \
@@ -96,6 +96,9 @@ constexpr uint32_t get_num_frags_x<WarpLayout::k4x1x1>() {
} else if (warp_layout == WarpLayout::k4x1x1) { \
constexpr WarpLayout WARP_LAYOUT = WarpLayout::k4x1x1; \
__VA_ARGS__ \
} else if (warp_layout == WarpLayout::k1x4x1) { \
constexpr WarpLayout WARP_LAYOUT = WarpLayout::k1x4x1; \
__VA_ARGS__ \
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported warp layout: " << int(warp_layout); \
2 changes: 1 addition & 1 deletion python/generate_batch_paged_prefill_inst.py
@@ -40,7 +40,7 @@ def get_cu_file_str(
dtype_out,
idtype,
):
warp_layout_choice = [0, 1]
warp_layout_choice = [0, 1, 2]
insts = "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<page_storage, {warp_layout}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>(
2 changes: 1 addition & 1 deletion python/generate_batch_ragged_prefill_inst.py
@@ -39,7 +39,7 @@ def get_cu_file_str(
dtype_out,
idtype,
):
warp_layout_choice = [0, 1]
warp_layout_choice = [0, 1, 2]
insts = "\n".join(
[
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{warp_layout}, {head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}, {idtype}>(
