bugfix: fix batch_prefill.cu in AOT mode after #554 (#559)
#554 didn't update `batch_prefill.cu` (which is used in AOT mode) to match the API change.
This PR fixes the issue.

cc @abcdabcd987
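
For context, a minimal sketch of the calling-convention change that #554 introduced and this commit propagates to the AOT build: the run functions no longer take a `bool return_lse` and allocate/return the LSE tensor themselves; the caller preallocates it and passes it as an optional. The surrounding tensors (`q`, workspace buffers, indptr arrays, etc.) and the `want_lse` flag below are illustrative assumptions, not code from the repository:

// Sketch only: caller now preallocates lse and passes std::optional<torch::Tensor>.
std::optional<torch::Tensor> maybe_lse;
if (want_lse) {  // want_lse: hypothetical caller-side flag
  // lse must be float32 with shape [nnz_qo, num_qo_heads], matching q's first two dims.
  maybe_lse = torch::empty({q.size(0), q.size(1)},
                           q.options().dtype(torch::kFloat32));
}
torch::Tensor o = BatchPrefillWithRaggedKVCacheRun(
    mask_mode_code, float_workspace_buffer, int_workspace_buffer, plan_info_vec,
    q, k, v, maybe_custom_mask, maybe_alibi_slopes, qo_indptr, kv_indptr,
    maybe_qk_indptr, layout, window_left, logits_soft_cap, sm_scale,
    rope_scale, rope_theta, maybe_lse);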
yzh119 authored Oct 26, 2024
1 parent 6227562 commit ea86f81
Showing 1 changed file with 18 additions and 24 deletions.
flashinfer-aot/csrc_aot/batch_prefill.cu
@@ -71,14 +71,14 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
   return plan_info.ToVector();
 }
 
-std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
+torch::Tensor BatchPrefillWithRaggedKVCacheRun(
     unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
     torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
     torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
     std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
     torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
     int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
-    bool return_lse) {
+    std::optional<torch::Tensor> maybe_lse) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(plan_info_vec);
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -98,10 +98,11 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
-  int64_t nnz_qo = q.size(0);
-  torch::Tensor lse = torch::empty({0});
-  if (return_lse) {
-    lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
+  if (maybe_lse) {
+    const auto& lse = *maybe_lse;
+    TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
+    TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
+    TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
   }
 
   void* float_buffer_ptr = float_workspace_buffer.data_ptr();
@@ -140,7 +141,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
                 : nullptr,
             /*q_offset=*/nullptr,
             /*k_rope_pos_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
-            /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
+            /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
             /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h,
             kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale,
             rope_theta);
@@ -187,22 +188,18 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
     });
   });
 
-  if (return_lse) {
-    return {o, lse};
-  } else {
-    return {o};
-  }
+  return o;
 }
 
-std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
+torch::Tensor BatchPrefillWithPagedKVCacheRun(
     unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
     torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
     torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
     std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
     torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
     torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
     unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
-    float rope_scale, float rope_theta, bool return_lse) {
+    float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
   PrefillPlanInfo plan_info;
   plan_info.FromVector(plan_info_vec);
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -221,10 +218,11 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
 
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
-  int64_t nnz_qo = q.size(0);
-  torch::Tensor lse = torch::empty({0});
-  if (return_lse) {
-    lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
+  if (maybe_lse) {
+    const auto& lse = *maybe_lse;
+    TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
+    TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
+    TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
   }
 
   void* float_buffer_ptr = static_cast<void*>(float_workspace_buffer.data_ptr());
@@ -277,7 +275,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
             maybe_qk_indptr.has_value() ? static_cast<IdType*>(maybe_qk_indptr->data_ptr())
                                         : nullptr,
             /*q_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
-            /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
+            /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
             /*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
             logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
@@ -323,9 +321,5 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
     });
   });
 
-  if (return_lse) {
-    return {o, lse};
-  } else {
-    return {o};
-  }
+  return o;
 }
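
The `TORCH_CHECK` lines this commit adds follow a common pattern for validating caller-provided output buffers before handing their raw pointers to a kernel. A self-contained sketch of the same pattern, assuming only libtorch (the function name `check_lse_buffer` is made up for illustration):

#include <torch/torch.h>
#include <optional>

// Validate an optional caller-provided LSE buffer against q's layout.
// Checks run only when the caller actually requested LSE output.
void check_lse_buffer(const torch::Tensor& q,
                      const std::optional<torch::Tensor>& maybe_lse) {
  if (!maybe_lse) return;  // LSE output not requested
  const auto& lse = *maybe_lse;
  TORCH_CHECK(lse.size(0) == q.size(0), "lse/q size(0) mismatch");
  TORCH_CHECK(lse.size(1) == q.size(1), "lse/q size(1) mismatch");
  TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
}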
