feat: JIT compilation #507

Merged
merged 65 commits, Oct 7, 2024

Commits (65)
6d333ba
upd
yzh119 Sep 2, 2024
c0797cd
wip
yzh119 Sep 3, 2024
96f34b0
wip
yzh119 Sep 13, 2024
6b721e6
upd
yzh119 Sep 19, 2024
f3bd765
upd
yzh119 Sep 19, 2024
41e810e
upd
yzh119 Sep 20, 2024
585a720
fix decode
yzh119 Sep 20, 2024
319f7f5
upd
yzh119 Sep 20, 2024
8254243
bugfix
yzh119 Sep 20, 2024
2663698
upd
yzh119 Sep 21, 2024
54dfcce
wip
yzh119 Sep 23, 2024
04288c2
upd
yzh119 Sep 23, 2024
b70e3e0
fix
yzh119 Sep 23, 2024
330075a
bugfix in prefill
yzh119 Sep 23, 2024
ab64084
upd
yzh119 Sep 24, 2024
b4a2eaf
bugfix
yzh119 Sep 24, 2024
65c142c
upd
yzh119 Sep 24, 2024
ebfeee0
remove unused code
yzh119 Sep 24, 2024
029ecbe
formatter
yzh119 Sep 24, 2024
2940033
rename handler to scheduler
yzh119 Sep 24, 2024
143515a
remove decode/prefill decl
yzh119 Sep 24, 2024
316d423
simplify setup.py
yzh119 Sep 24, 2024
e890b40
upd
yzh119 Sep 25, 2024
0c4ea17
upd
yzh119 Sep 25, 2024
bbdb49b
bugfix
yzh119 Sep 25, 2024
66296e4
fix sparse
yzh119 Sep 25, 2024
091608a
another set of bugfix
yzh119 Sep 25, 2024
6a10f4b
fix sliding window
yzh119 Sep 25, 2024
81dd8dc
rebase
yzh119 Sep 25, 2024
1a17ead
wip: fix src directory
yzh119 Sep 25, 2024
af488ce
fix generate dispatch inc
yzh119 Sep 25, 2024
d013658
formatter
yzh119 Sep 25, 2024
dd5a9e2
fix generation scripts
yzh119 Sep 25, 2024
0133f04
fix cmakes
yzh119 Sep 25, 2024
5c92b21
remove page_storage
yzh119 Sep 25, 2024
860db4f
handler to scheduler
yzh119 Sep 25, 2024
1ad1428
bugfix on logits_hook
yzh119 Sep 25, 2024
f860c79
bugfix
yzh119 Sep 25, 2024
c4f9bb0
fix gen single_decode_inst script
yzh119 Sep 25, 2024
e30c0cb
bugfix
yzh119 Sep 25, 2024
ec5ef97
a bunch of bugfix
yzh119 Sep 25, 2024
3648dcb
fix sign
yzh119 Sep 25, 2024
5033115
bugfix
yzh119 Sep 25, 2024
645bf18
upd
yzh119 Sep 25, 2024
2138d57
formatter
yzh119 Sep 25, 2024
752047f
upd
yzh119 Sep 25, 2024
172950c
bugfix
yzh119 Sep 25, 2024
6e5a0a6
formatter
yzh119 Sep 25, 2024
60316f0
fix initialization of params
yzh119 Sep 25, 2024
f555390
rename DTypeOut to DTypeO
yzh119 Sep 25, 2024
e2cfa89
formatter
yzh119 Sep 25, 2024
b5c36d0
remove unused include
yzh119 Sep 26, 2024
2f5f71e
upd
yzh119 Sep 27, 2024
7eac0ba
upd
yzh119 Sep 27, 2024
a53a3f8
upd
yzh119 Oct 1, 2024
78b9678
upd
yzh119 Oct 1, 2024
72a7a1f
upd
yzh119 Oct 3, 2024
768fa2b
upd
yzh119 Oct 3, 2024
922c4aa
upd
yzh119 Oct 3, 2024
47476a7
load aot ops if existed
yzh119 Oct 4, 2024
409f461
upd
yzh119 Oct 5, 2024
502826d
upd
yzh119 Oct 5, 2024
64f1918
upd
yzh119 Oct 6, 2024
8bb68de
tests passed
yzh119 Oct 7, 2024
3f42c03
trailing empty lines
yzh119 Oct 7, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -12,6 +12,8 @@ src/dispatch.inc
src/generated/
python/csrc/generated/
python/flashinfer/_build_meta.py
python/flashinfer/jit/aot_config.py
flashinfer-aot/csrc_aot/generated/

# Generated documentation files
docs/generated
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 360 files
242 changes: 114 additions & 128 deletions CMakeLists.txt

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion cmake/config.cmake
@@ -24,7 +24,6 @@ set(FLASHINFER_FASTDEQUANT_TEST ON)
set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
1 change: 1 addition & 0 deletions flashinfer-aot/3rdparty
12 changes: 12 additions & 0 deletions flashinfer-aot/MANIFEST.in
@@ -0,0 +1,12 @@
# sdist & wheel
include version.txt
recursive-include include *
recursive-include csrc *
recursive-include 3rdparty/cutlass *

# wheel-only
exclude flashinfer/_build_meta.py

# Unneeded files
prune */__pycache__
global-exclude *.so
1 change: 1 addition & 0 deletions flashinfer-aot/csrc
@@ -18,11 +18,25 @@

#include <flashinfer/activation.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

__device__ __forceinline__ float silu(const float& val) {
return val / (1.0f + __expf(-val));
}

__device__ __forceinline__ float gelu(const float& val) {
constexpr float kAlpha = M_SQRT1_2;
return val * 0.5f * (1.0f + ::erf(val * kAlpha));
}

__device__ __forceinline__ float gelu_tanh(const float& val) {
const float cdf =
0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
return val * cdf;
}

void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
int d = input.size(-1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
@@ -33,7 +47,7 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::silu_kernel>
flashinfer::activation::act_and_mul_kernel<c_type, silu>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
static_cast<c_type*>(input.data_ptr()), d);

@@ -51,7 +65,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_tanh_kernel>
flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
static_cast<c_type*>(input.data_ptr()), d);

@@ -69,7 +83,7 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) {
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_kernel>
flashinfer::activation::act_and_mul_kernel<c_type, gelu>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
static_cast<c_type*>(input.data_ptr()), d);

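The hunks above swap the header-provided functors (e.g. flashinfer::activation::silu_kernel) for local __device__ functions and pass them to act_and_mul_kernel as compile-time template arguments. Below is a minimal stand-alone sketch of that pattern, not FlashInfer's implementation: the scalar kernel body, launch shape, and data layout ([num_tokens, 2*d] input, activation on the first half multiplied by the second half) are simplifying assumptions for illustration only.

#include <cuda_runtime.h>
#include <cstdio>

// Same definition as in the diff above.
__device__ __forceinline__ float silu(const float& val) {
  return val / (1.0f + __expf(-val));
}

// Simplified stand-in for flashinfer::activation::act_and_mul_kernel:
// the activation is a compile-time template parameter, so each
// instantiation (silu, gelu, gelu_tanh) inlines its own math.
// Assumed layout: input is [num_tokens, 2*d]; out is [num_tokens, d].
template <float (*Act)(const float&)>
__global__ void act_and_mul_sketch(float* out, const float* input, int d) {
  const int token = blockIdx.x;
  for (int i = threadIdx.x; i < d; i += blockDim.x) {
    const float x = input[token * 2 * d + i];      // activated half
    const float y = input[token * 2 * d + d + i];  // gating half
    out[token * d + i] = Act(x) * y;
  }
}

int main() {
  const int num_tokens = 4, d = 128;
  float *input, *out;
  cudaMallocManaged(&input, num_tokens * 2 * d * sizeof(float));
  cudaMallocManaged(&out, num_tokens * d * sizeof(float));
  for (int i = 0; i < num_tokens * 2 * d; ++i) input[i] = 0.01f * i;

  act_and_mul_sketch<silu><<<num_tokens, 128>>>(out, input, d);
  cudaDeviceSynchronize();
  printf("out[0] = %f\n", out[0]);

  cudaFree(input);
  cudaFree(out);
  return 0;
}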
205 changes: 205 additions & 0 deletions flashinfer-aot/csrc_aot/batch_decode.cu
@@ -0,0 +1,205 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>

#include <flashinfer/attention/decode_params.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/variants.cuh>
#include <optional>

#include "pytorch_extension_utils.h"

namespace flashinfer {

template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream);

} // namespace flashinfer

std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data,
torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer,
torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
indptr = indptr.to(torch::kCPU);

DecodePlanInfo plan_info;

using IdType = int32_t;
// check indptr has idtype int32
TORCH_CHECK(indptr.scalar_type() == torch::kInt32, "indptr must be int32");
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

auto q_scalar_type = empty_q_data.scalar_type();
auto kv_scalar_type = empty_kv_data.scalar_type();

DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
using DTypeQ = q_type;
using DTypeKV = kv_type;
using DTypeO = DTypeQ;
return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
using AttentionVariant =
ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
/*use_sliding_window=*/true,
USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

cudaError_t status = DecodePlan<HEAD_DIM, POS_ENCODING_MODE, AttentionVariant>(
static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()),
static_cast<void*>(page_locked_int_workspace_buffer.data_ptr()),
int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, page_size, enable_cuda_graph,
/*stream=*/torch_current_stream);

TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
});
});
});

return plan_info.ToVector();
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, torch::Tensor q,
std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
DecodePlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
bool paged_kv_defined = paged_kv_cache.has_value();
auto device = q.device();
int64_t batch_size = q.size(0);
int64_t num_qo_heads = q.size(1);
int64_t num_kv_heads, page_size;
if (paged_kv_defined) {
if (kv_layout == QKVLayout::kHND) {
num_kv_heads = paged_kv_cache->size(2);
page_size = paged_kv_cache->size(3);
} else {
page_size = paged_kv_cache->size(2);
num_kv_heads = paged_kv_cache->size(3);
}
} else {
if (kv_layout == QKVLayout::kHND) {
num_kv_heads = paged_k_cache->size(1);
page_size = paged_k_cache->size(2);
} else {
page_size = paged_k_cache->size(1);
num_kv_heads = paged_k_cache->size(2);
}
}
uint32_t head_dim = q.size(2);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
torch::Tensor o = torch::empty_like(q);
torch::Tensor lse;
if (return_lse) {
lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32)));
}

TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");

void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

using IdType = int32_t;
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

// get q_scalar_type and kv_scalar_type
auto q_scalar_type = q.scalar_type();
auto kv_scalar_type =
paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();

DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
using DTypeQ = q_type;
using DTypeKV = kv_type;
using DTypeO = DTypeQ;
return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
return DISPATCH_LOGITS_SOFT_CAP(logits_soft_cap > 0, USE_LOGITS_SOFT_CAP, [&] {
using ParamsT = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;
using AttentionVariant =
ComposedAttention<ParamsT, get_variant_code(/*use_custom_mask=*/false,
/*use_sliding_window=*/true,
USE_LOGITS_SOFT_CAP, /*use_alibi=*/false)>;

paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
static_cast<DTypeKV*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
: nullptr),
static_cast<DTypeKV*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr),
static_cast<DTypeKV*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr),
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
/*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
/*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap,
sm_scale, rope_scale, rope_theta);

DTypeO* tmp_v = nullptr;
float* tmp_s = nullptr;
params.request_indices =
GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
params.kv_tile_indices =
GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
params.kv_chunk_size_ptr =
GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_chunk_size_ptr_offset);
if (plan_info.split_kv) {
tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
if (plan_info.enable_cuda_graph) {
params.block_valid_mask =
GetPtrFromBaseOffset<bool>(int_buffer, plan_info.block_valid_mask_offset);
}
}
params.padded_batch_size = plan_info.padded_batch_size;

cudaError_t status =
flashinfer::BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, POS_ENCODING_MODE,
AttentionVariant>(
params, tmp_v, tmp_s, /*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
});
});
});

if (return_lse) {
return {o, lse};
} else {
return {o};
}
}
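The new file above splits batched decode into a Plan step, which runs the scheduler against a host-side indptr and serializes a DecodePlanInfo into a std::vector<int64_t>, and a Run step, which re-hydrates that vector and launches the dispatched kernel. A hedged usage sketch of that handshake follows; it is not part of the PR and not a tested program, only an illustration built from the signatures above (the wrapper function name decode_plan_then_run, the RoPE defaults, and kv_layout_code = 0 for NHD are assumptions).

#include <torch/extension.h>

#include <optional>
#include <vector>

// Re-declared from this file so the sketch stands alone when linked
// against the same extension.
std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
    bool use_logits_soft_cap, unsigned int head_dim, torch::Tensor empty_q_data,
    torch::Tensor empty_kv_data, torch::Tensor float_workspace_buffer,
    torch::Tensor int_workspace_buffer, torch::Tensor page_locked_int_workspace_buffer,
    torch::Tensor indptr, unsigned int batch_size, unsigned int num_qo_heads,
    unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph);

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, torch::Tensor q,
    std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
    std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
    torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
    std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse);

// Plan once for a given batch shape, then Run with the serialized plan.
// All tensor arguments are assumed to be prepared by the caller.
std::vector<torch::Tensor> decode_plan_then_run(
    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
    torch::Tensor page_locked_int_workspace_buffer, torch::Tensor q,
    torch::Tensor paged_kv_cache,  // combined KV cache, NHD layout assumed
    torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
    torch::Tensor paged_kv_last_page_len, unsigned int num_kv_heads,
    unsigned int page_size, float sm_scale) {
  const unsigned int batch_size = static_cast<unsigned int>(q.size(0));
  const unsigned int num_qo_heads = static_cast<unsigned int>(q.size(1));
  const unsigned int head_dim = static_cast<unsigned int>(q.size(2));
  // Plan only reads dtypes from these, so empty tensors suffice.
  torch::Tensor empty_q_data = torch::empty({0}, q.options());
  torch::Tensor empty_kv_data = torch::empty({0}, paged_kv_cache.options());

  std::vector<int64_t> plan_info_vec = BatchDecodeWithPagedKVCachePlan(
      /*use_logits_soft_cap=*/false, head_dim, empty_q_data, empty_kv_data,
      float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer,
      paged_kv_indptr, batch_size, num_qo_heads, num_kv_heads, page_size,
      /*enable_cuda_graph=*/false);

  return BatchDecodeWithPagedKVCacheRun(
      float_workspace_buffer, int_workspace_buffer, plan_info_vec, q,
      paged_kv_cache, /*paged_k_cache=*/std::nullopt, /*paged_v_cache=*/std::nullopt,
      paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
      /*alibi_slopes=*/std::nullopt, /*kv_layout_code=*/0, /*window_left=*/-1,
      /*logits_soft_cap=*/0.f, sm_scale, /*rope_scale=*/1.f, /*rope_theta=*/1e4f,
      /*return_lse=*/true);  // returns {o, lse}
}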